Maharshi Gor committed
Commit 5f3e7d5 · 1 Parent(s): 5d637a7

Made workflows a submodule

.gitignore CHANGED
@@ -16,8 +16,8 @@ __pycache__/
 *ipynb
 .vscode/
 
-eval-queue/
-eval-results/
-eval-queue-bk/
-eval-results-bk/
+eval-*/
 logs/
+data/
+outputs/
+hf_cache/
.gitmodules ADDED
@@ -0,0 +1,3 @@
+[submodule "shared/workflows"]
+	path = shared/workflows
+	url = https://github.com/qanta-challenge/ai_workflows
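With `shared/workflows` tracked as a submodule, a plain clone or pull no longer fetches its contents. A minimal sketch of the extra checkout step (standard git commands; `<repo-url>` is a placeholder for this repo's clone URL):

```bash
# Fresh clone: fetch the main repo and the shared/workflows submodule together
git clone --recurse-submodules <repo-url>

# Existing checkout: initialize and fetch the submodule after pulling this commit
git submodule update --init shared/workflows
```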
app.py CHANGED
@@ -30,8 +30,8 @@ from envs import (
     RESULTS_REPO,
     SERVER_REFRESH_INTERVAL,
 )
-from workflows import factory
-from workflows.configs import AVAILABLE_MODELS
+from shared.workflows import factory
+from shared.workflows.configs import AVAILABLE_MODELS
 
 
 def restart_space():
check_repos.py DELETED
@@ -1,26 +0,0 @@
-from huggingface_hub import HfApi
-
-from src.envs import LLM_CACHE_REPO, QUEUE_REPO, RESULTS_REPO, TOKEN
-
-
-def check_and_create_dataset_repo(repo_id: str):
-    api = HfApi(token=TOKEN)
-    try:
-        api.repo_info(repo_id=repo_id, repo_type="dataset")
-        print(f"{repo_id} exists")
-    except Exception:
-        print(f"Creating {repo_id}")
-        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=True)
-
-
-def check_and_create_repos():
-    print("1. QUEUE Repository")
-    check_and_create_dataset_repo(QUEUE_REPO)
-    print("2. RESULTS Repository")
-    check_and_create_dataset_repo(RESULTS_REPO)
-    print("3. LLM Cache Repository")
-    check_and_create_dataset_repo(LLM_CACHE_REPO)
-
-
-if __name__ == "__main__":
-    check_and_create_repos()
shared/__init__.py ADDED
File without changes
shared/workflows ADDED
@@ -0,0 +1 @@
+Subproject commit 9d8bfae31f4db8b25165c950742d13e6c4e80de8
src/components/model_pipeline/model_pipeline.py CHANGED
@@ -12,8 +12,8 @@ from components.model_pipeline.state_manager import (
 from components.model_step.model_step import ModelStepComponent
 from components.structs import ModelStepUIState, PipelineState, PipelineUIState
 from components.utils import make_state
-from workflows.structs import ModelStep, Workflow
-from workflows.validators import WorkflowValidationError, WorkflowValidator
+from shared.workflows.structs import ModelStep, Workflow
+from shared.workflows.validators import WorkflowValidationError, WorkflowValidator
 
 from .state_manager import get_output_panel_state
 
src/components/model_pipeline/state_manager.py CHANGED
@@ -13,8 +13,8 @@ from components import typed_dicts as td
 from components import utils
 from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
 from envs import DOCS_REPO_BRANCH, DOCS_REPO_URL
-from workflows.factory import create_new_llm_step
-from workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
+from shared.workflows.factory import create_new_llm_step
+from shared.workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
 
 
 def get_output_panel_state(workflow: Workflow) -> dict:
@@ -113,9 +113,7 @@ class PipelineStateManager:
         pipeline_change = not pipeline_change
         return new_state_dict, pipeline_change
 
-    def move_down(
-        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
-    ) -> td.PipelineStateDict:
+    def move_down(self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int) -> td.PipelineStateDict:
         """Move a step down in the pipeline."""
         new_state_dict, change = self._move_step(state_dict, position, "down")
         if change:
@@ -189,7 +187,9 @@ class PipelineStateManager:
             help_text = f"Refer to the <a href='{repo_files_url}/pipeline-schema.md' target='_blank'>documentation</a> for the correct pipeline schema."
         else:
             error_type = "Unexpected Error"
-            help_text = f"Please report this issue to us at <a href='{DOCS_REPO_URL}/issues' target='_blank'>GitHub Issues</a>."
+            help_text = (
+                f"Please report this issue to us at <a href='{DOCS_REPO_URL}/issues' target='_blank'>GitHub Issues</a>."
+            )
 
         return error_template.format(error_type=error_type, error_message=str(e), help_text=help_text)
 
src/components/model_pipeline/tossup_pipeline.py CHANGED
@@ -6,15 +6,13 @@ from components import commons
 from components.structs import PipelineUIState, TossupPipelineState
 from components.typed_dicts import TossupPipelineStateDict
 from display.formatting import tiny_styled_warning
-from workflows.structs import Buzzer, TossupWorkflow
+from shared.workflows.structs import Buzzer, TossupWorkflow
 
 from .model_pipeline import PipelineInterface
 from .state_manager import BasePipelineValidator, TossupPipelineStateManager
 
 
-def toggleable_slider(
-    value, minimum, maximum, step, toggle_value=False, label=None, info=None, min_width=200, scale=1
-):
+def toggleable_slider(value, minimum, maximum, step, toggle_value=False, label=None, info=None, min_width=200, scale=1):
     with gr.Column(elem_classes="toggleable", min_width=min_width, scale=scale):
         show_label = label is not None
         checkbox = gr.Checkbox(label=label, value=toggle_value, container=False, info=info, show_label=show_label)
@@ -90,9 +88,7 @@ class TossupPipelineInterface(PipelineInterface):
             ),
         )
 
-    def _render_buzzer_panel(
-        self, buzzer: Buzzer, prob_slider_supported: bool, selected_model_name: str | None = None
-    ):
+    def _render_buzzer_panel(self, buzzer: Buzzer, prob_slider_supported: bool, selected_model_name: str | None = None):
         with gr.Row(elem_classes="control-panel"):
             self.confidence_slider = gr.Slider(
                 minimum=0.0,
src/components/model_step/model_step.py CHANGED
@@ -7,8 +7,8 @@ from gradio.components import FormComponent
 from app_configs import UNSELECTED_VAR_NAME
 from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager
 from components.typed_dicts import PipelineStateDict
+from shared.workflows.structs import ModelStep
 from utils import get_full_model_name
-from workflows.structs import ModelStep
 
 from .state_manager import ModelStepStateManager
 from .ui_components import InputRowButtonGroup, OutputRowButtonGroup
src/components/model_step/state_manager.py CHANGED
@@ -6,8 +6,8 @@ from loguru import logger
 from app_configs import UNSELECTED_VAR_NAME
 from components.model_pipeline.state_manager import ModelStepUIState
 from components.utils import DIRECTIONS, move_item
+from shared.workflows.structs import FieldType, ModelStep
 from utils import get_model_and_provider
-from workflows.structs import FieldType, ModelStep
 
 
 class ModelStepStateManager:
src/components/quizbowl/bonus.py CHANGED
@@ -12,9 +12,9 @@ from components import commons
 from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
 from components.typed_dicts import PipelineStateDict
 from display.formatting import styled_error
+from shared.workflows import factory
+from shared.workflows.qb_agents import QuizBowlBonusAgent
 from submission import submit
-from workflows import factory
-from workflows.qb_agents import QuizBowlBonusAgent
 
 from . import populate, validation
 from .plotting import create_bonus_confidence_plot, create_bonus_html
src/components/quizbowl/tossup.py CHANGED
@@ -12,9 +12,9 @@ from components import commons
 from components.model_pipeline.tossup_pipeline import TossupPipelineInterface, TossupPipelineState
 from components.typed_dicts import TossupInterfaceDefaults, TossupPipelineStateDict
 from display.formatting import styled_error
+from shared.workflows import factory
+from shared.workflows.qb_agents import QuizBowlTossupAgent, TossupResult
 from submission import submit
-from workflows import factory
-from workflows.qb_agents import QuizBowlTossupAgent, TossupResult
 
 from . import populate, validation
 from .plotting import (
src/components/quizbowl/validation.py CHANGED
@@ -3,8 +3,8 @@ from typing import Literal
 from app_configs import AVAILABLE_MODELS, CONFIGS
 from components.structs import PipelineState, TossupPipelineState
 from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
-from workflows.structs import TossupWorkflow, Workflow
-from workflows.validators import WorkflowValidationError, WorkflowValidator
+from shared.workflows.structs import TossupWorkflow, Workflow
+from shared.workflows.validators import WorkflowValidationError, WorkflowValidator
 
 
 def validate_workflow(
src/components/structs.py CHANGED
@@ -2,7 +2,7 @@ from typing import Any, Literal
 
 from pydantic import BaseModel, Field, model_validator
 
-from workflows.structs import ModelStep, TossupWorkflow, Workflow
+from shared.workflows.structs import ModelStep, TossupWorkflow, Workflow
 
 
 def make_step_id(step_number: int):
src/components/typed_dicts.py CHANGED
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
 
-from workflows.structs import TossupWorkflow, Workflow
+from shared.workflows.structs import TossupWorkflow, Workflow
 
 
 # TypedDicts for workflows/structs.py
src/populate.py CHANGED
@@ -49,7 +49,7 @@ def get_tossups_leaderboard_df(repo_dir: str, eval_split: str) -> pd.DataFrame:
             row["Win Rate w/ Human (Aggressive)"] = metrics["human_win_rate_strict"]
             eval_results.append(row)
         except Exception as e:
-            logger.error(f"Error processing model result: {e}")
+            logger.error(f"Error processing model result '{username}/{model_name}': {e}")
             continue
 
     return pd.DataFrame(eval_results)
@@ -72,7 +72,7 @@ def get_bonuses_leaderboard_df(repo_dir: str, eval_split: str) -> pd.DataFrame:
             }
             eval_results.append(row)
         except Exception as e:
-            logger.error(f"Error processing model result: {e}")
+            logger.error(f"Error processing model result '{username}/{model_name}': {e}")
             continue
 
     return pd.DataFrame(eval_results)
@@ -96,9 +96,7 @@ def get_evaluation_queue_df(save_path: str, cols: list) -> list[pd.DataFrame]:
             all_evals.append(data)
         elif ".md" not in entry:
             # this is a folder
-            sub_entries = [
-                e for e in os.listdir(f"{save_path}/{entry}") if os.path.isfile(e) and not e.startswith(".")
-            ]
+            sub_entries = [e for e in os.listdir(f"{save_path}/{entry}") if os.path.isfile(e) and not e.startswith(".")]
             for sub_entry in sub_entries:
                 file_path = os.path.join(save_path, entry, sub_entry)
                 with open(file_path) as fp:
src/submission/_submit.py ADDED
@@ -0,0 +1,119 @@
+import json
+import os
+from datetime import datetime, timezone
+
+from src.display.formatting import styled_error, styled_message, styled_warning
+from src.envs import API, EVAL_REQUESTS_PATH, TOKEN, QUEUE_REPO
+from src.submission.check_validity import (
+    already_submitted_models,
+    check_model_card,
+    get_model_size,
+    is_model_on_hub,
+)
+
+REQUESTED_MODELS = None
+USERS_TO_SUBMISSION_DATES = None
+
+def add_new_eval(
+    model: str,
+    base_model: str,
+    revision: str,
+    precision: str,
+    weight_type: str,
+    model_type: str,
+):
+    global REQUESTED_MODELS
+    global USERS_TO_SUBMISSION_DATES
+    if not REQUESTED_MODELS:
+        REQUESTED_MODELS, USERS_TO_SUBMISSION_DATES = already_submitted_models(EVAL_REQUESTS_PATH)
+
+    user_name = ""
+    model_path = model
+    if "/" in model:
+        user_name = model.split("/")[0]
+        model_path = model.split("/")[1]
+
+    precision = precision.split(" ")[0]
+    current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
+
+    if model_type is None or model_type == "":
+        return styled_error("Please select a model type.")
+
+    # Does the model actually exist?
+    if revision == "":
+        revision = "main"
+
+    # Is the model on the hub?
+    if weight_type in ["Delta", "Adapter"]:
+        base_model_on_hub, error, _ = is_model_on_hub(model_name=base_model, revision=revision, token=TOKEN, test_tokenizer=True)
+        if not base_model_on_hub:
+            return styled_error(f'Base model "{base_model}" {error}')
+
+    if not weight_type == "Adapter":
+        model_on_hub, error, _ = is_model_on_hub(model_name=model, revision=revision, token=TOKEN, test_tokenizer=True)
+        if not model_on_hub:
+            return styled_error(f'Model "{model}" {error}')
+
+    # Is the model info correctly filled?
+    try:
+        model_info = API.model_info(repo_id=model, revision=revision)
+    except Exception:
+        return styled_error("Could not get your model information. Please fill it up properly.")
+
+    model_size = get_model_size(model_info=model_info, precision=precision)
+
+    # Were the model card and license filled?
+    try:
+        license = model_info.cardData["license"]
+    except Exception:
+        return styled_error("Please select a license for your model")
+
+    modelcard_OK, error_msg = check_model_card(model)
+    if not modelcard_OK:
+        return styled_error(error_msg)
+
+    # Seems good, creating the eval
+    print("Adding new eval")
+
+    eval_entry = {
+        "model": model,
+        "base_model": base_model,
+        "revision": revision,
+        "precision": precision,
+        "weight_type": weight_type,
+        "status": "PENDING",
+        "submitted_time": current_time,
+        "model_type": model_type,
+        "likes": model_info.likes,
+        "params": model_size,
+        "license": license,
+        "private": False,
+    }
+
+    # Check for duplicate submission
+    if f"{model}_{revision}_{precision}" in REQUESTED_MODELS:
+        return styled_warning("This model has been already submitted.")
+
+    print("Creating eval file")
+    OUT_DIR = f"{EVAL_REQUESTS_PATH}/{user_name}"
+    os.makedirs(OUT_DIR, exist_ok=True)
+    out_path = f"{OUT_DIR}/{model_path}_eval_request_False_{precision}_{weight_type}.json"
+
+    with open(out_path, "w") as f:
+        f.write(json.dumps(eval_entry))
+
+    print("Uploading eval file")
+    API.upload_file(
+        path_or_fileobj=out_path,
+        path_in_repo=out_path.split("eval-queue/")[1],
+        repo_id=QUEUE_REPO,
+        repo_type="dataset",
+        commit_message=f"Add {model} to eval queue",
+    )
+
+    # Remove the local file
+    os.remove(out_path)
+
+    return styled_message(
+        "Your request has been submitted to the evaluation queue!\nPlease wait for up to an hour for the model to show in the PENDING list."
+    )
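For orientation, a hypothetical call to `add_new_eval` might look like the sketch below (the argument values are illustrative, not taken from this repo): the function validates the model on the Hub, writes a PENDING request JSON under `EVAL_REQUESTS_PATH/<user>`, uploads it to `QUEUE_REPO`, and returns a styled HTML status string.

```python
# Hypothetical usage sketch of add_new_eval; values are illustrative only.
from src.submission._submit import add_new_eval

message = add_new_eval(
    model="some-org/some-model",        # "<user>/<model>" form is assumed here
    base_model="",                       # only consulted for Delta/Adapter weights
    revision="",                         # empty string falls back to "main"
    precision="float16 (recommended)",   # only text before the first space is kept
    weight_type="Original",
    model_type="fine-tuned",
)
print(message)  # styled success/warning/error HTML
```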
src/submission/check_validity.py ADDED
@@ -0,0 +1,99 @@
+import json
+import os
+import re
+from collections import defaultdict
+from datetime import datetime, timedelta, timezone
+
+import huggingface_hub
+from huggingface_hub import ModelCard
+from huggingface_hub.hf_api import ModelInfo
+from transformers import AutoConfig
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+def check_model_card(repo_id: str) -> tuple[bool, str]:
+    """Checks if the model card and license exist and have been filled"""
+    try:
+        card = ModelCard.load(repo_id)
+    except huggingface_hub.utils.EntryNotFoundError:
+        return False, "Please add a model card to your model to explain how you trained/fine-tuned it."
+
+    # Enforce license metadata
+    if card.data.license is None:
+        if not ("license_name" in card.data and "license_link" in card.data):
+            return False, (
+                "License not found. Please add a license to your model card using the `license` metadata or a"
+                " `license_name`/`license_link` pair."
+            )
+
+    # Enforce card content
+    if len(card.text) < 200:
+        return False, "Please add a description to your model card, it is too short."
+
+    return True, ""
+
+def is_model_on_hub(model_name: str, revision: str, token: str = None, trust_remote_code=False, test_tokenizer=False) -> tuple[bool, str]:
+    """Checks if the model model_name is on the hub, and whether it (and its tokenizer) can be loaded with AutoClasses."""
+    try:
+        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
+        if test_tokenizer:
+            try:
+                tk = AutoTokenizer.from_pretrained(model_name, revision=revision, trust_remote_code=trust_remote_code, token=token)
+            except ValueError as e:
+                return (
+                    False,
+                    f"uses a tokenizer which is not in a transformers release: {e}",
+                    None
+                )
+            except Exception as e:
+                return (False, "'s tokenizer cannot be loaded. Is your tokenizer class in a stable transformers release, and correctly configured?", None)
+        return True, None, config
+
+    except ValueError:
+        return (
+            False,
+            "needs to be launched with `trust_remote_code=True`. For safety reason, we do not allow these models to be automatically submitted to the leaderboard.",
+            None
+        )
+
+    except Exception as e:
+        return False, "was not found on hub!", None
+
+
+def get_model_size(model_info: ModelInfo, precision: str):
+    """Gets the model size from the configuration, or the model name if the configuration does not contain the information."""
+    try:
+        model_size = round(model_info.safetensors["total"] / 1e9, 3)
+    except (AttributeError, TypeError):
+        return 0  # Unknown model sizes are indicated as 0, see NUMERIC_INTERVALS in app.py
+
+    size_factor = 8 if (precision == "GPTQ" or "gptq" in model_info.modelId.lower()) else 1
+    model_size = size_factor * model_size
+    return model_size
+
+def get_model_arch(model_info: ModelInfo):
+    """Gets the model architecture from the configuration"""
+    return model_info.config.get("architectures", "Unknown")
+
+def already_submitted_models(requested_models_dir: str) -> set[str]:
+    """Gather a list of already submitted models to avoid duplicates"""
+    depth = 1
+    file_names = []
+    users_to_submission_dates = defaultdict(list)
+
+    for root, _, files in os.walk(requested_models_dir):
+        current_depth = root.count(os.sep) - requested_models_dir.count(os.sep)
+        if current_depth == depth:
+            for file in files:
+                if not file.endswith(".json"):
+                    continue
+                with open(os.path.join(root, file), "r") as f:
+                    info = json.load(f)
+                    file_names.append(f"{info['model']}_{info['revision']}_{info['precision']}")
+
+                    # Select organisation
+                    if info["model"].count("/") == 0 or "submitted_time" not in info:
+                        continue
+                    organisation, _ = info["model"].split("/")
+                    users_to_submission_dates[organisation].append(info["submitted_time"])
+
+    return set(file_names), users_to_submission_dates
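A quick sketch of how these helpers compose (hypothetical model name; note that `is_model_on_hub` and `already_submitted_models` actually return a 3-tuple and a `(set, defaultdict)` pair respectively, despite their narrower annotations):

```python
# Hypothetical usage sketch of the validity helpers above.
from src.submission.check_validity import already_submitted_models, check_model_card, is_model_on_hub

on_hub, error, config = is_model_on_hub("some-org/some-model", revision="main", test_tokenizer=True)
if not on_hub:
    print(f"Model some-org/some-model {error}")  # error strings are phrased to follow the model name

card_ok, card_error = check_model_card("some-org/some-model")

# Gather prior submissions to de-duplicate against (directory name assumed)
seen, users_to_dates = already_submitted_models("eval-queue")
```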
src/submission/structs.py CHANGED
@@ -3,7 +3,7 @@ from typing import Dict, List, Literal, Optional
 
 from pydantic import BaseModel, Field
 
-from workflows.structs import TossupWorkflow, Workflow
+from shared.workflows.structs import TossupWorkflow, Workflow
 
 CompetitionType = Literal["tossup", "bonus"]
 SubmissionType = Literal["python_file", "simple_workflow", "complex_workflow"]
src/submission/submit.py CHANGED
@@ -13,8 +13,8 @@ from loguru import logger
 from app_configs import DAILY_SUBMISSION_LIMIT_PER_USER
 from display.formatting import styled_error, styled_message
 from envs import API, EVAL_REQUESTS_PATH, EXAMPLES_PATH, OWNER, QUEUE_REPO
+from shared.workflows.structs import TossupWorkflow, Workflow
 from submission.structs import CompetitionType, Submission, SubmissionStatus
-from workflows.structs import TossupWorkflow, Workflow
 
 
 def get_user_submissions(username: str, competition_type: str, pattern: str = None) -> list[Submission]:
@@ -238,7 +238,7 @@ def load_submission(model_name: str, competition_type: CompetitionType, profile:
 
 if __name__ == "__main__":
     # Example usage
-    from workflows.factory import create_quizbowl_simple_step_initial_setup
+    from shared.workflows.factory import create_quizbowl_simple_step_initial_setup
 
     # Create workflow
     model_step = create_quizbowl_simple_step_initial_setup()
src/workflows/README.md DELETED
@@ -1,129 +0,0 @@
-# Workflows Subpackage
-
-This subpackage provides a framework for defining, validating, and executing workflows composed of interconnected model steps with dependency management.
-
-## Overview
-
-The workflows subpackage enables the creation and execution of workflows where multiple model steps can be combined, with outputs from earlier steps feeding into inputs of later steps. The package handles dependency resolution, execution order, and error handling.
-
-## Components
-
-### `structs.py`
-
-Contains the core data structures used throughout the workflow system:
-
-- `InputField`: Represents an input field with name, description, and variable reference
-- `OutputField`: Represents an output field with name, type, and description
-- `ModelStep`: Represents a single step in a workflow with input fields, output fields, and model details
-- `Workflow`: A collection of ModelSteps with their identifiers
-- `TossupWorkflow`: Specialized workflow for quizbowl tossup questions with buzzing capability
-
-### `configs.py`
-
-Provides configuration settings and constants:
-
-- `AVAILABLE_MODELS`: Supported model configurations from various providers
-- `TYPE_MAP`: Mapping of supported field types to Python types
-- `FUNCTION_MAP`: Built-in transformation functions for input/output processing
-
-### `utils.py`
-
-Provides utility functions for workflow operations:
-
-- `create_dependency_graph`: Builds a dependency graph representing the execution order constraints
-- `topological_sort`: Sorts steps in execution order based on their dependencies
-- `detect_cycles`: Identifies cyclic dependencies in workflow definitions
-
-### `executors.py`
-
-Handles the execution of workflows:
-
-- `execute_model_step`: Executes a single model step with input processing and output collection
-- `execute_simple_workflow`: Handles single-step workflows
-- `execute_multi_step_workflow`: Manages multi-step workflows with dependency resolution
-- `execute_workflow`: Main entry point that routes to appropriate executor based on workflow complexity
-
-### `validators.py`
-
-Provides workflow validation functionality:
-
-- `ValidationErrorType`: Enumeration of possible validation error types
-- `WorkflowValidationError`: Base class for validation errors
-- Validation functions for steps, DAGs, variables, and types
-
-### `errors.py`
-
-Defines custom exceptions for workflow-related errors:
-
-- `WorkflowError`: Base class for workflow errors
-- `CyclicDependencyError`: Raised when detecting cycles in the workflow graph
-- `UnknownVariableError`: Raised when a step requires a variable that's not provided or produced
-
-## Usage Example
-
-```python
-from workflows.structs import InputField, ModelStep, OutputField, Workflow
-
-# Define a workflow with two steps
-step1 = ModelStep(
-    id="step1",
-    model="gpt-4o-mini",
-    provider="OpenAI",
-    call_type="llm",
-    system_prompt="Step1 processing",
-    input_fields=[InputField(name="value", description="Input value", variable="input.value")],
-    output_fields=[OutputField(name="result", description="Processed result", type="str", func="upper")],
-)
-
-step2 = ModelStep(
-    id="step2",
-    model="gpt-4o-mini",
-    provider="OpenAI",
-    call_type="llm",
-    system_prompt="Step2 processing",
-    input_fields=[InputField(name="result", description="Result from step1", variable="step1.result")],
-    output_fields=[OutputField(name="final", description="Final output", type="str", func="lower")],
-)
-
-workflow = Workflow(
-    steps={"step1": step1, "step2": step2},
-    inputs=["input.value"],
-    outputs={"final": "step2.final"}
-)
-
-# Execute the workflow
-from workflows.executors import execute_workflow
-
-result = execute_workflow(
-    workflow=workflow,
-    input_values={"input.value": "Hello, World!"},
-    return_full_content=True,
-    logprob_step="step2"
-)
-
-# Access results
-final_output = result["final_outputs"]["final"]
-intermediate_results = result["intermediate_outputs"]
-step_contents = result["step_contents"]
-logprob = result["logprob"]
-```
-
-## Error Handling
-
-The workflows system provides robust error handling:
-
-- Detects cyclic dependencies in workflow definitions
-- Validates input/output variable references
-- Ensures all required inputs are provided
-- Supports custom validation rules through the validation system
-- Provides detailed error messages for debugging
-
-## Extending the Workflows System
-
-To extend the workflows system:
-
-1. Add new model step types by extending the `ModelStep` class
-2. Create custom field types by extending validation in the execution logic
-3. Implement additional error types in `errors.py` for specialized error handling
-4. Add new transformation functions to `FUNCTION_MAP` in `configs.py`
-5. Create specialized workflow types by extending the `Workflow` class
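The deleted README's usage example still applies against the submodule, modulo the new package root; a minimal sketch of the updated imports (the `structs` path is confirmed by the diffs above, while the `executors` path is assumed to move along with the rest of the package):

```python
# Old: from workflows.structs import ...  /  from workflows.executors import ...
from shared.workflows.structs import InputField, ModelStep, OutputField, Workflow
from shared.workflows.executors import execute_workflow  # assumed post-move path
```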
src/workflows/configs.py DELETED
@@ -1,56 +0,0 @@
-"""
-Configuration settings for the workflows package.
-
-This module contains configuration settings and constants used across the workflows package,
-including model configurations, workflow settings, and other package-wide constants.
-"""
-
-AVAILABLE_MODELS = {
-    "OpenAI/gpt-4o": {
-        "model": "gpt-4o-2024-11-20",
-        "logprobs": True,
-    },
-    "OpenAI/gpt-4o-mini": {
-        "model": "gpt-4o-mini-2024-07-18",
-        "logprobs": True,
-    },
-    "OpenAI/gpt-3.5-turbo": {
-        "model": "gpt-3.5-turbo-0125",
-    },
-    "Anthropic/claude-3-7-sonnet": {
-        "model": "claude-3-7-sonnet-20250219",
-    },
-    "Anthropic/claude-3-5-sonnet": {
-        "model": "claude-3-5-sonnet-20241022",
-    },
-    "Anthropic/claude-3-5-haiku": {
-        "model": "claude-3-5-haiku-20241022",
-    },
-    "Cohere/command-r": {
-        "model": "command-r-08-2024",
-        "logprobs": True,
-    },
-    "Cohere/command-r-plus": {
-        "model": "command-r-plus-08-2024",
-        "logprobs": True,
-    },
-    "Cohere/command-r7b": {
-        "model": "command-r7b-12-2024",
-        "logprobs": False,
-    },
-}
-
-# Function mapping for input/output transformations
-TYPE_MAP = {
-    "str": str,
-    "int": int,
-    "float": float,
-    "bool": bool,
-}
-
-FUNCTION_MAP = {
-    "upper": str.upper,
-    "lower": str.lower,
-    "len": len,
-    "split": str.split,
-}
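These tables were consumed by the executors when coercing output types and applying input transforms (see `get_type` and `create_processed_inputs` in `executors.py` below); a small sketch of the lookup pattern, using the mappings exactly as defined above:

```python
# Lookup pattern mirrored from executors.py: named entries first, eval() fallback.
TYPE_MAP = {"str": str, "int": int, "float": float, "bool": bool}
FUNCTION_MAP = {"upper": str.upper, "lower": str.lower, "len": len, "split": str.split}

field_type = TYPE_MAP.get("str", eval("str"))  # -> <class 'str'>
transform = FUNCTION_MAP.get("upper")          # -> str.upper
print(transform("hello"))                      # HELLO
```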
src/workflows/errors.py DELETED
@@ -1,63 +0,0 @@
-"""
-Custom exceptions for workflow validation and execution errors.
-
-This module defines the exception hierarchy for the workflows package, enabling
-specific error types to be raised and caught during workflow validation and execution.
-Each exception provides detailed error messages to help diagnose and fix issues in
-workflow definitions or execution.
-
-Exception hierarchy:
-- WorkflowError (base class)
-- UnknownVariableError (missing variable reference)
-- CyclicDependencyError (circular dependencies)
-- FunctionNotFoundError (missing function reference)
-"""
-
-
-# Define custom exceptions for workflow errors
-class WorkflowError(Exception):
-    """
-    Base exception class for all workflow-related errors.
-
-    This is the parent class for all workflow-specific exceptions and can be used
-    to catch any error from the workflows package.
-    """
-
-    pass
-
-
-class UnknownVariableError(WorkflowError):
-    """
-    Raised when a workflow step references a variable that doesn't exist.
-
-    This typically occurs when a step's input field references a variable that is neither
-    provided as an external input nor produced as an output by any previous step.
-    """
-
-    def __init__(self, var: str):
-        super().__init__(f"Unknown variable referenced: {var}")
-
-
-class CyclicDependencyError(WorkflowError):
-    """
-    Raised when a cyclic dependency is detected in a workflow.
-
-    A cyclic dependency occurs when there is a circular reference in the workflow graph,
-    such as step A depending on step B, which depends on step A. Such workflows cannot
-    be executed because there's no valid order to process the steps.
-    """
-
-    def __init__(self):
-        super().__init__("Cyclic dependency detected in workflow")
-
-
-class FunctionNotFoundError(WorkflowError):
-    """
-    Raised when a referenced function cannot be found during workflow execution.
-
-    This typically occurs when a step references a function that doesn't exist in
-    the available function registry or namespace.
-    """
-
-    def __init__(self, func_name: str):
-        super().__init__(f"Function not found: {func_name}")
src/workflows/executors.py DELETED
@@ -1,673 +0,0 @@
-"""
-Core workflow execution functionality.
-
-This module handles the execution of defined workflows, including input processing,
-dependency-based execution order, model calling, and output collection. It integrates
-with the litellm library to handle model interactions.
-
-Key components:
-- Utility functions for input/output transformation
-- Input processing and validation
-- Model step execution with support for log probabilities
-- Complete workflow execution with dependency resolution
-- Support for both simple (single-step) and multi-step workflows
-- Structured output collection with intermediate results
-
-The module orchestrates the execution of steps in the correct order based on their
-dependencies and manages the flow of data between steps. It supports:
-- Full content tracking for debugging
-- Log probability calculation for specific steps
-- Flexible input/output transformations
-- Error handling and validation
-"""
-
-from typing import Any, TypedDict
-
-import pydantic
-
-from .configs import FUNCTION_MAP, TYPE_MAP
-from .errors import WorkflowError
-from .llms import completion
-from .structs import InputField, ModelStep, OutputField, Workflow
-from .utils import create_dependency_graph, topological_sort
-
-def get_type(type_str: str) -> type:
-    """
-    Converts a type string to its corresponding Python type.
-
-    This function maps type strings to their actual Python type objects. It first checks
-    the TYPE_MAP dictionary for predefined mappings, and if not found, falls back to
-    evaluating the type string directly.
-
-    Args:
-        type_str (str): A string representation of a type (e.g., "str", "int", "list[str]")
-
-    Returns:
-        type: The corresponding Python type object
-
-    Note:
-        Uses eval() for non-predefined types, which has security implications if used
-        with untrusted input. This is intended for internal use with validated type strings.
-    """
-    return TYPE_MAP.get(type_str, eval(type_str))
-
-
-def create_processed_inputs(model_step: ModelStep, available_vars: dict[str, Any]) -> dict[str, Any]:
-    """
-    Creates processed inputs for a model step.
-
-    This function extracts and processes the required inputs for a model step based on
-    its input field definitions. It retrieves values from the available variables dictionary
-    and applies any specified transformations.
-
-    Args:
-        model_step (ModelStep): The model step for which to create processed inputs.
-        available_vars (dict[str, Any]): Dictionary of variables available for use as inputs.
-            Keys are variable names, values are the variable values.
-
-    Returns:
-        dict[str, Any]: A dictionary of processed inputs ready for use by the model step.
-            Keys are input field names, values are the processed input values.
-
-    Raises:
-        WorkflowError: If a required variable is not found in available_vars,
-            or if a specified transformation function is not available.
-
-    Example:
-        >>> available_vars = {"step1.output": "Hello World"}
-        >>> create_processed_inputs(model_step, available_vars)
-        {"input_field_name": "HELLO WORLD"}  # If upper transformation was specified
-    """
-    processed_inputs: dict[str, Any] = {}
-    for input_field in model_step.input_fields:
-        var = input_field.variable
-        value = available_vars[var]
-        if input_field.func is not None:
-            func = FUNCTION_MAP.get(input_field.func)
-            func = func or eval(input_field.func)
-            value = func(value)
-        processed_inputs[input_field.name] = value
-    return processed_inputs
-
-
-class ModelStepResult(TypedDict):
-    """
-    Result of executing a model step.
-
-    This TypedDict contains the outputs and metadata from executing a single model step,
-    including the processed output values, the full response content, and log probability
-    information when requested.
-
-    Attributes:
-        outputs (dict[str, Any]): A dictionary of processed outputs from the model step,
-            with keys matching the output field names.
-        content (str | None): The full content of the model's response, only populated
-            if return_full_content is True.
-        logprob (float | None): The log probability of the model step output, only populated
-            if logprobs is True.
-    """
-
-    # A dictionary of processed outputs from the model step,
-    # with keys matching the output field names.
-    outputs: dict[str, Any]
-
-    # The full content of the model step.
-    content: str | None
-
-    # The log probability of the model step output if requested.
-    logprob: float | None
-
-
-class WorkflowOutput(TypedDict):
-    """
-    Result of executing a complete workflow.
-
-    This TypedDict contains the outputs and metadata from executing a workflow,
-    including final outputs, intermediate values, step contents, and log probabilities.
-
-    Attributes:
-        final_outputs (dict[str, Any]): The final output values produced by the workflow,
-            with keys matching the names defined in workflow.outputs.
-        intermediate_outputs (dict[str, Any]): All computed values during workflow execution,
-            including both external inputs and outputs from all steps.
-        step_contents (dict[str, Any]): Full response content for each step, keyed by step ID.
-            Only populated if return_full_content is True.
-        logprob (float | None): The log probability of the specified step's output.
-            Only populated if logprob_step is specified.
-    """
-
-    # A dictionary of the workflow's outputs, with keys matching the variables defined in workflow.outputs.
-    final_outputs: dict[str, Any]
-
-    # A dictionary of all computed values during workflow execution, including intermediate results.
-    intermediate_outputs: dict[str, Any]
-
-    # A dictionary of step contents, only populated if return_full_content is True.
-    step_contents: dict[str, Any]
-
-    # The log probability of the workflow output if requested.
-    logprob: float | None
-
-
-# %%
-def execute_model_step(
-    model_step: ModelStep,
-    available_vars: dict[str, Any],
-    return_full_content: bool = False,
-    logprobs: bool = False,
-) -> ModelStepResult:
-    """
-    Executes a model step using the provided available variables.
-
-    This function handles the complete execution of a model step, including:
-    1. Processing inputs using variable references and transformations
-    2. Constructing the appropriate prompt for the model
-    3. Calling the model via litellm with structured output
-    4. Processing and validating the model's response
-    5. Applying any output transformations
-
-    The function supports different providers and model types through the litellm
-    integration, allowing for a consistent interface regardless of the underlying model.
-
-    Args:
-        model_step (ModelStep): The model step to execute, containing model details,
-            input/output specifications, and system prompt.
-        available_vars (dict[str, Any]): A dictionary of all variables available to this step,
-            including outputs from previous steps and external inputs.
-        return_full_content (bool, optional): If True, includes the full model response content
-            in the result. Defaults to False.
-        logprobs (bool, optional): If True, calculates and returns log probability information
-            for the model response. Defaults to False.
-
-    Returns:
-        ModelStepResult: A TypedDict containing processed outputs, optional full content,
-            and optional log probability information.
-
-    Raises:
-        WorkflowError: If there's an error in input processing, model execution,
-            or output validation.
-
-    Example:
-        >>> step = ModelStep(
-        ...     id="summarize",
-        ...     model="gpt-3.5-turbo",
-        ...     provider="openai",
-        ...     call_type="llm",
-        ...     system_prompt="Summarize the text",
-        ...     input_fields=[InputField(name="text", variable="input_text", description="Text to summarize")],
-        ...     output_fields=[OutputField(name="summary", type="str", description="Summary of the text")]
-        ... )
-        >>> result = execute_model_step(step, {"input_text": "Long text to be summarized..."})
-        >>> summary = result["outputs"]["summary"]
-    """
-    # Ensure inputs are processed using the specified functions in input_fields.
-    processed_inputs = create_processed_inputs(model_step, available_vars)
-
-    # Construct the input prompt for the model
-    input_str = "\n".join(f"{k}: {v}" for k, v in processed_inputs.items())
-    step_result = f"Inputs: \n{input_str}"
-
-    # Define the expected output fields and their types
-    fields = {
-        field.name: (get_type(field.type), pydantic.Field(..., description=field.description))
-        for field in model_step.output_fields
-    }
-    ModelResponse = pydantic.create_model("ModelResponse", **fields)
-
-    # Execute the model step using litellm
-    api_response = completion(
-        model=f"{model_step.provider}/{model_step.model}",
-        system=model_step.system_prompt,
-        prompt=step_result,
-        response_format=ModelResponse,
-        temperature=model_step.temperature,
-        logprobs=logprobs,
-    )
-
-    # Map the parsed response to the output fields
-    outputs = {field.name: api_response["output"][field.name] for field in model_step.output_fields}
-    result = ModelStepResult(outputs=outputs, content=None, logprob=None)
-    if return_full_content:
-        result["content"] = api_response["content"]
-    if logprobs:
-        result["logprob"] = api_response.get("logprob")
-    return result
-
-
-def execute_multi_step_workflow(
-    workflow: Workflow,
-    input_values: dict[str, Any],
-    return_full_content: bool = False,
-    logprob_step: str | None = None,
-) -> WorkflowOutput:
-    """
-    Execute the given workflow as a computational graph.
-
-    This function orchestrates the complete execution of a workflow by:
-
-    1. Validating and populating initial values using the provided external inputs
-    2. Building a dependency graph between workflow steps
-    3. Determining a valid execution order using topological sorting
-    4. Executing each step in the correct order, with inputs from previous steps
-    5. Collecting and returning the final outputs
-
-    The execution process ensures that all dependencies are satisfied before a step
-    is executed, and that the data flows correctly between steps according to the
-    variable references defined in each step's input fields.
-
-    Args:
-        workflow (Workflow): The workflow to execute, containing steps, their
-            dependencies, and input/output specifications.
-        input_values (dict[str, Any]): External input values to be used by the workflow.
-            Keys should match the required workflow.inputs.
-        return_full_content (bool, optional): If True, returns the full content of each step.
-            Defaults to False.
-        logprob_step (str, optional): The ID of the step to use for log probability calculation.
-            Defaults to None.
-
-    Returns:
-        WorkflowOutput: A dictionary of workflow outputs, including final outputs, intermediate outputs, and step contents.
-
-    Raises:
-        UnknownVariableError: If an input_field references a variable that is not
-            provided externally nor produced by any step.
-        CyclicDependencyError: If the workflow contains a circular dependency that
-            prevents a valid execution order.
-        FunctionNotFoundError: If a transformation function specified in input_fields.func
-            or output_fields.func is not available.
-        WorkflowError: For any other workflow-related errors, such as missing required inputs.
-
-    Example:
-        >>> workflow = Workflow(
-        ...     steps={
-        ...         "extract": ModelStep(...),  # A step that extracts entities
-        ...         "analyze": ModelStep(...)   # A step that analyzes the entities
-        ...     },
-        ...     inputs=["text"],
-        ...     outputs={"sentiment": "analyze.sentiment", "entities": "extract.entities"}
-        ... )
-        >>> final_outputs, computed_values, step_contents = execute_workflow(workflow, {"text": "Apple is launching a new product tomorrow."})
-        >>> print(final_outputs["sentiment"])
-        "positive"
-        >>> print(final_outputs["entities"])
-        ["Apple", "product"]
-    """
-    # Step 1: Pre-populate computed values with external workflow inputs.
-    computed_values: dict[str, Any] = {}
-    for var in workflow.inputs:
-        if var not in input_values:
-            raise WorkflowError(f"Missing required workflow input: {var}")
-        computed_values[var] = input_values[var]
-
-    # Step 2: Build dependency graph among model steps.
-    # For each step, examine its input_fields. If an input is not in the pre-populated external inputs,
-    # then it is expected to be produced by some step. Otherwise, raise an error.
-    dependencies = create_dependency_graph(workflow, input_values)
-
-    # Step 3: Determine the execution order of the steps using topological sort.
-    # Raises an error if a cycle is detected.
-    execution_order = topological_sort(dependencies)
-
-    # Step 4: Execute steps in topological order.
-    step_contents: dict[str, Any] = {}
-    logprob = None
-    for step_id in execution_order:
-        step = workflow.steps[step_id]
-        return_logprobs = logprob_step == step_id
-        # Execute the step
-        result = execute_model_step(
-            step, computed_values, return_full_content=return_full_content, logprobs=return_logprobs
-        )
-        if return_logprobs:
-            logprob = result["logprob"]
-        if return_full_content:
-            step_contents[step_id] = result["content"]
-        outputs = {f"{step_id}.{k}": v for k, v in result["outputs"].items()}
-        computed_values.update(outputs)
-
-    # Step 5: Gather and return workflow outputs.
-    final_outputs: dict[str, Any] = {}
-    for target, var in workflow.outputs.items():
-        if var not in computed_values:
-            raise WorkflowError(
-                f"Workflow output variable {var} was not produced. Computed values: {computed_values.keys()}"
-            )
-        final_outputs[target] = computed_values[var]
-
-    return WorkflowOutput(
-        final_outputs=final_outputs,
-        intermediate_outputs=computed_values,
-        step_contents=step_contents,
-        logprob=logprob,
-    )
-
-
-def execute_simple_workflow(
-    workflow: Workflow,
-    input_values: dict[str, Any],
-    return_full_content: bool = False,
-    logprob_step: bool | str = False,
-) -> WorkflowOutput:
-    """
-    Execute a simple workflow with a single step.
-
-    This is an optimized version of workflow execution for workflows containing only one step.
-    It bypasses the dependency graph building and topological sorting steps, providing a more
-    direct execution path for simple workflows.
-
-    Args:
-        workflow (Workflow): The workflow to execute, which must contain exactly one step.
-        input_values (dict[str, Any]): External input values to be used by the workflow.
-            Keys should match the required workflow.inputs.
-        return_full_content (bool, optional): If True, includes the full model response content
-            in the result. Defaults to False.
-        logprobs (bool, optional): If True, calculates and returns log probability information
-            for the model response. Defaults to False.
-
-    Returns:
-        WorkflowOutput: A TypedDict containing the workflow outputs, intermediate values,
-            optional step contents, and optional log probability information.
-
-    Raises:
-        WorkflowError: If the workflow has more than one step or if required inputs are missing.
-
-    Example:
-        >>> workflow = Workflow(
-        ...     steps={"extract": ModelStep(...)},
-        ...     inputs=["text"],
-        ...     outputs={"entities": "extract.entities"}
-        ... )
-        >>> result = execute_simple_workflow(workflow, {"text": "Apple is launching a new product."})
-        >>> entities = result["final_outputs"]["entities"]
-    """
-    if len(workflow.steps) != 1:
-        raise WorkflowError("Simple workflow must have exactly one step")
-
-    # Get the single step
-    step = list(workflow.steps.values())[0]
-
-    logprobs = logprob_step is True or logprob_step == step.id
-
-    # Validate inputs
-    for var in workflow.inputs:
-        if var not in input_values:
-            raise WorkflowError(f"Missing required workflow input: {var}")
-
-    # Execute the step
-    step_result = execute_model_step(step, input_values, return_full_content=return_full_content, logprobs=logprobs)
-    step_outputs = step_result["outputs"]
-    step_contents = {step.id: step_result["content"]} if return_full_content else {}
-    # Prepare the final outputs
-    final_outputs = {}
-    for target, var in workflow.outputs.items():
-        if var.startswith(f"{step.id}."):
-            output_key = var.split(".", 1)[1]
-            if output_key in step_outputs:
-                final_outputs[target] = step_outputs[output_key]
-            else:
-                raise WorkflowError(f"Workflow output variable {var} was not produced")
-        else:
-            raise WorkflowError(f"Invalid output mapping: {var} does not match step ID {step.id}")
-
-    # Prepare computed values (prefixed with step ID)
-    computed_values = input_values | {f"{step.id}.{k}": v for k, v in step_outputs.items()}
-
-    return WorkflowOutput(
-        final_outputs=final_outputs,
-        intermediate_outputs=computed_values,
-        step_contents=step_contents,
-        logprob=step_result.get("logprob"),
-    )
-
-
-def execute_workflow(
-    workflow: Workflow,
-    input_values: dict[str, Any],
-    return_full_content: bool = False,
-    logprob_step: str | bool = False,
-) -> WorkflowOutput:
-    """
-    Main entry point for executing workflows of any complexity.
-
-    This function serves as a router that delegates to the appropriate specialized
-    execution function based on the complexity of the workflow:
-    - For single-step workflows, it calls execute_simple_workflow
-    - For multi-step workflows, it calls execute_multi_step_workflow
-
-    This abstraction allows callers to use a consistent interface regardless of
-    the workflow's complexity.
-
-    Args:
-        workflow (Workflow): The workflow to execute, containing steps, their
-            dependencies, and input/output specifications.
-        input_values (dict[str, Any]): External input values to be used by the workflow.
-            Keys should match the required workflow.inputs.
-        return_full_content (bool, optional): If True, includes the full model response
-            content in the result. Defaults to False.
-        logprob_step (str | bool, optional): Either a string with the ID of the step for which
-            to calculate log probability, or a boolean flag.
-            If False, no log probabilities are calculated.
-            Defaults to False.
-
-    Returns:
-        WorkflowOutput: A TypedDict containing the workflow outputs, intermediate values,
-            optional step contents, and optional log probability information.
-
-    Raises:
-        WorkflowError: For any workflow-related errors, such as missing required inputs,
-            circular dependencies, or invalid variable references.
-
-    Example:
-        >>> workflow = Workflow(
-        ...     steps={"extract": ModelStep(...), "analyze": ModelStep(...)},
-        ...     inputs=["text"],
-        ...     outputs={"sentiment": "analyze.sentiment"}
-        ... )
-        >>> result = execute_workflow(
-        ...     workflow,
-        ...     {"text": "Apple is launching a new product."},
-        ...     return_full_content=True,
-        ...     logprob_step="analyze"
-        ... )
-        >>> print(result["final_outputs"]["sentiment"])
-        "positive"
-    """
-    if len(workflow.steps) > 1:
-        return execute_multi_step_workflow(workflow, input_values, return_full_content, logprob_step)
-    else:
-        return execute_simple_workflow(workflow, input_values, return_full_content, logprob_step)
-
-
-def run_examples():
-    """
-    Runs example workflows demonstrating key functionality and error handling.
-
-    This function creates and executes three different example workflows to showcase:
-
-    1. Successful workflow execution:
-       - A linear two-step workflow with proper dependency flow
-       - Input transformation using the 'upper' function
-       - Output transformation using the 'lower' function
-       - Proper variable passing between steps
-
-    2. Cyclic dependency detection:
-       - A workflow with two steps that depend on each other circularly
-       - Demonstrates the error handling for cyclic dependencies
-       - Shows how the system prevents infinite execution loops
-
-    3. Unknown variable detection:
-       - A workflow that references a variable not provided as input or by any step
-       - Demonstrates validation of variable references
-       - Shows error handling for missing dependencies
-
-    Each example prints its result or the error encountered, making this function
-    useful for testing and demonstration purposes.
-
-    Returns:
-        None: This function prints its results and doesn't return a value.
-    """
-    print("Example 1: Successful Workflow Execution")
-    # Example 1: Simple linear workflow.
-    # External input "input.value" is provided. Two steps:
-    # - step1 takes "input.value" and produces "step1.result".
-    # - step2 uses "step1.result" and produces "step2.final".
-    from workflows.structs import ModelStep, Workflow
-
-    workflow_success = Workflow(
-        steps={
-            "step1": ModelStep(
-                id="step1",
-                model="gpt-4o-mini",
-                provider="OpenAI",
-                call_type="llm",
-                system_prompt="Step1 processing",
-                input_fields=[InputField(name="value", description="Input value", variable="input.value")],
-                output_fields=[OutputField(name="result", description="Processed result", type="str", func="upper")],
-            ),
-            "step2": ModelStep(
-                id="step2",
-                model="gpt-4o-mini",
-                provider="OpenAI",
-                call_type="llm",
-                system_prompt="Step2 processing",
-                input_fields=[InputField(name="result", description="Result from step1", variable="step1.result")],
-                output_fields=[OutputField(name="final", description="Final output", type="str", func="lower")],
-            ),
-        },
-        inputs=["input.value"],
-        outputs={"final": "step2.final"},
-    )
-    input_values_success = {"input.value": "Hello, World!"}
-    try:
-        outputs = execute_workflow(workflow_success, input_values_success)
-        print("Workflow outputs:", outputs)
-    except WorkflowError as e:
-        print("Workflow failed with error:", e)
-
-    print("\nExample 2: Cyclic Dependency Workflow")
-    # Example 2: Cyclic dependency.
-    # stepA depends on an output from stepB and vice versa.
-    workflow_cycle = Workflow(
-        steps={
-            "stepA": ModelStep(
-                id="stepA",
-                model="gpt-4o-mini",
-                provider="OpenAI",
-                call_type="llm",
-                system_prompt="StepA processing",
-                input_fields=[
-                    InputField(name="input", description="Input from stepB", variable="stepB.output", func="identity")
-                ],
-                output_fields=[OutputField(name="output", description="Output from A", type="str", func="upper")],
-            ),
-            "stepB": ModelStep(
-                id="stepB",
-                model="gpt-4o-mini",
-                provider="OpenAI",
-                call_type="llm",
-                system_prompt="StepB processing",
-                input_fields=[
-                    InputField(name="input", description="Input from stepA", variable="stepA.output", func="identity")
-                ],
-                output_fields=[OutputField(name="output", description="Output from B", type="str", func="upper")],
-            ),
-        },
-        inputs=[],  # no external inputs
-        outputs={"output": "stepB.output"},
-    )
-    try:
-        outputs = execute_workflow(workflow_cycle, {})
-        print("Workflow outputs:", outputs)
-    except WorkflowError as e:
-        print("Workflow failed with error:", e)
-
-    print("\nExample 3: Unknown Variable Dependency Workflow")
-    # Example 3: A workflow that references a variable not provided as an input or produced by any step.
-    workflow_unknown = Workflow(
-        steps={
-            "stepX": ModelStep(
-                id="stepX",
-                model="gpt-4o-mini",
-                provider="OpenAI",
-                call_type="llm",
-                system_prompt="StepX processing",
- input_fields=[
596
- InputField(
597
- name="input", description="Non-existent input", variable="nonexistent.value", func="identity"
598
- )
599
- ],
600
- output_fields=[OutputField(name="output", description="Output from X", type="str", func="upper")],
601
- )
602
- },
603
- inputs=[], # no external inputs
604
- outputs={"output": "stepX.output"},
605
- )
606
- try:
607
- outputs = execute_workflow(workflow_unknown, {})
608
- print("Workflow outputs:", outputs)
609
- except WorkflowError as e:
610
- print("Workflow failed with error:", e)
611
-
612
-
613
- if __name__ == "__main__":
614
- # create example of model_step
615
- model_step = ModelStep(
616
- id="step1",
617
- model="gpt-4o-mini",
618
- provider="OpenAI",
619
- call_type="llm",
620
- system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
621
- input_fields=[
622
- InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
623
- InputField(name="n", description="The number of entities to return", variable="n", func=None),
624
- ],
625
- output_fields=[
626
- OutputField(
627
- name="entities",
628
- description="The first N entities in the string as a list of strings",
629
- type="list[str]",
630
- func=None,
631
- ),
632
- OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
633
- ],
634
- )
635
-
636
- processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}
637
- processed_inputs = create_processed_inputs(model_step, processed_inputs)
638
- print(processed_inputs)
639
-
640
- run_examples()
641
-
642
- # %%
643
-
644
- # Example usage
645
- if __name__ == "__main__":
646
- # Define a simple model step
647
- model_step = ModelStep(
648
- id="step1",
649
- model="gpt-4o-mini",
650
- provider="OpenAI",
651
- call_type="llm",
652
- system_prompt="You are a simple NLP tool that takes a string, and a number N, and return the first N entities in the string, and the total count of entities in the string.",
653
- input_fields=[
654
- InputField(name="sentence", description="The sentence to process", variable="sentence", func=None),
655
- InputField(name="n", description="The number of entities to return", variable="n", func=None),
656
- ],
657
- output_fields=[
658
- OutputField(
659
- name="entities",
660
- description="The first N entities in the string as a list of strings",
661
- type="list[str]",
662
- func=None,
663
- ),
664
- OutputField(name="count", description="The total count of entities in the string", type="int", func=None),
665
- ],
666
- )
667
-
668
- # Define processed inputs
669
- processed_inputs = {"sentence": "Abdul Akbar is a good person, but Jesus is the son of God.", "n": 3}
670
-
671
- # Execute the model step
672
- outputs = execute_model_step(model_step, processed_inputs)
673
- print(outputs)
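
For quick reference, a minimal sketch of calling this executor after the move (it assumes the module is now importable as shared.workflows.executors with an unchanged API; the step, model, and input names are illustrative):

    # Hypothetical usage sketch; assumes the executors API is unchanged
    # after the move into the shared/workflows submodule.
    from shared.workflows.executors import execute_workflow
    from shared.workflows.structs import InputField, ModelStep, OutputField, Workflow

    workflow = Workflow(
        steps={
            "A": ModelStep(
                id="A",
                name="Summarizer",
                model="gpt-4o-mini",
                provider="OpenAI",
                call_type="llm",
                system_prompt="Summarize the given text in one sentence.",
                input_fields=[InputField(name="text", description="Text to summarize", variable="text")],
                output_fields=[OutputField(name="summary", description="One-sentence summary", type="str")],
            )
        },
        inputs=["text"],
        outputs={"summary": "A.summary"},
    )
    # Single-step workflows are routed to execute_simple_workflow internally.
    result = execute_workflow(workflow, {"text": "Quizbowl is a buzzer-based trivia game."})
    print(result["final_outputs"]["summary"])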
 
src/workflows/factory.py DELETED
@@ -1,176 +0,0 @@
- # %%
- from .structs import (
-     Buzzer,
-     BuzzerMethod,
-     CallType,
-     InputField,
-     ModelStep,
-     OutputField,
-     TossupWorkflow,
-     Workflow,
- )
- 
- INITIAL_SYS_PROMPT = """You are a helpful performant question answering bot.
- Given a question clue, output your most likely guess in a couple words with a calibrated confidence for the guess.
- """
- 
- 
- def create_empty_bonus_workflow():
-     return Workflow(
-         inputs=["leadin", "part"],
-         outputs={"answer": None, "confidence": None, "explanation": None},
-         steps={},
-     )
- 
- 
- def create_empty_tossup_workflow():
-     return TossupWorkflow(
-         inputs=["question_text"],
-         outputs={"answer": None, "confidence": None},
-         steps={},
-     )
- 
- 
- def create_first_step_input_fields() -> list[InputField]:
-     return [
-         InputField(
-             name="question",
-             description="The question text progressively revealed to the agent so far.",
-             variable="question_text",
-         )
-     ]
- 
- 
- def create_empty_input_field() -> list[InputField]:
-     return [InputField(name="", description="", variable="question_text")]
- 
- 
- def create_quizbowl_simple_step_initial_setup():
-     return ModelStep(
-         id="simple_step",
-         name="Quizbowl Simple Step",
-         model="",
-         provider="",
-         temperature=0.7,
-         call_type="llm",
-         system_prompt=INITIAL_SYS_PROMPT,
-         input_fields=[
-             InputField(name="question", description="The question to answer", variable="question"),
-         ],
-         output_fields=[
-             OutputField(name="answer", description="The most likely answer", type="str"),
-             OutputField(name="confidence", description="The confidence of the answer", type="float"),
-         ],
-     )
- 
- 
- def create_new_llm_step(step_id: str, name: str) -> ModelStep:
-     return ModelStep(
-         id=step_id,
-         name=name,
-         model="gpt-4o",
-         provider="OpenAI",
-         call_type="llm",
-         temperature=0.7,
-         system_prompt="",
-         input_fields=create_empty_input_field(),
-         output_fields=[OutputField(name="", description="")],
-     )
- 
- 
- def create_first_llm_step() -> ModelStep:
-     return ModelStep(
-         id="A",
-         name="",
-         model="gpt-4o",
-         provider="OpenAI",
-         call_type="llm",
-         temperature=0.7,
-         system_prompt="",
-         input_fields=create_first_step_input_fields(),  # already returns a list of InputField
-         output_fields=[OutputField(name="", description="")],
-     )
- 
- 
- def create_simple_qb_tossup_workflow():
-     return TossupWorkflow(
-         inputs=["question_text"],
-         outputs={"answer": "A.answer", "confidence": "A.confidence"},
-         steps={
-             "A": ModelStep(
-                 id="A",
-                 name="Tossup Agent",
-                 model="gpt-4o-mini",
-                 provider="OpenAI",
-                 call_type="llm",
-                 temperature=0.3,
-                 system_prompt="You are a helpful assistant that can answer questions.",
-                 input_fields=[InputField(name="question", description="The question text", variable="question_text")],
-                 output_fields=[
-                     OutputField(
-                         name="answer",
-                         description="The best guess at the answer to the question",
-                         type="str",
-                     ),
-                     OutputField(
-                         name="confidence",
-                         description="The confidence in the answer, ranging from 0 to 1 in increments of 0.05.",
-                         type="float",
-                     ),
-                 ],
-             )
-         },
-         buzzer=Buzzer(
-             confidence_threshold=0.75,
-             prob_threshold=None,
-             method=BuzzerMethod.AND,
-         ),
-     )
- 
- 
- BONUS_SYS_PROMPT = """You are a quizbowl player answering bonus questions. For each part:
- 1. Read the leadin and part carefully
- 2. Provide a concise answer
- 3. Rate your confidence (0-1)
- 4. Explain your reasoning
- 
- Format your response as:
- ANSWER: <your answer>
- CONFIDENCE: <0-1>
- EXPLANATION: <your reasoning>"""
- 
- 
- def create_simple_qb_bonus_workflow() -> Workflow:
-     """Create a simple model step for bonus questions."""
-     return Workflow(
-         inputs=["leadin", "part"],
-         outputs={"answer": "A.answer", "confidence": "A.confidence", "explanation": "A.explanation"},
-         steps={
-             "A": ModelStep(
-                 id="A",
-                 name="Bonus Agent",
-                 model="gpt-4o-mini",
-                 provider="OpenAI",
-                 temperature=0.3,
-                 call_type=CallType.LLM,
-                 system_prompt=BONUS_SYS_PROMPT,
-                 input_fields=[
-                     InputField(
-                         name="question_leadin",
-                         description="The leadin text for the bonus question",
-                         variable="leadin",
-                     ),
-                     InputField(
-                         name="question_part",
-                         description="The specific part text to answer",
-                         variable="part",
-                     ),
-                 ],
-                 output_fields=[
-                     OutputField(name="answer", description="The predicted answer", type="str"),
-                     OutputField(name="confidence", description="Confidence in the answer (0-1)", type="float"),
-                     OutputField(name="explanation", description="Short explanation for the answer", type="str"),
-                 ],
-             )
-         },
-     )
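
A minimal sketch of how these factory helpers plug into the tossup agent (the function names are the ones defined above; the import path assumes the post-move shared.workflows layout):

    # Hypothetical usage sketch; the import path is an assumption.
    from shared.workflows.factory import create_simple_qb_tossup_workflow

    tossup_workflow = create_simple_qb_tossup_workflow()
    # The attached Buzzer gates buzzing on the step's confidence output.
    print(tossup_workflow.buzzer.confidence_threshold)  # 0.75
    print(tossup_workflow.outputs)  # {'answer': 'A.answer', 'confidence': 'A.confidence'}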
 
src/workflows/llmcache.py DELETED
@@ -1,488 +0,0 @@
- import hashlib
- import json
- import os
- import sqlite3
- import threading
- import time
- from pathlib import Path
- from typing import Any, Optional
- 
- from datasets import Dataset, load_dataset, load_from_disk
- from huggingface_hub import snapshot_download
- from loguru import logger
- 
- 
- def load_dataset_from_hf(repo_id, local_dir):
-     snapshot_download(
-         repo_id=repo_id,
-         local_dir=local_dir,
-         repo_type="dataset",
-         tqdm_class=None,
-         etag_timeout=30,
-         token=os.environ["HF_TOKEN"],
-     )
-     return load_dataset(repo_id)
- 
- 
- class CacheDB:
-     """Handles database operations for storing and retrieving cache entries."""
- 
-     def __init__(self, db_path: Path):
-         """Initialize database connection.
- 
-         Args:
-             db_path: Path to SQLite database file
-         """
-         self.db_path = db_path
-         self.lock = threading.Lock()
- 
-         # Initialize the database
-         try:
-             self.initialize_db()
-         except Exception as e:
-             logger.exception(f"Failed to initialize database: {e}")
-             logger.warning(f"Please provide a different filepath or remove the file at {self.db_path}")
-             raise
- 
-     def initialize_db(self) -> None:
-         """Initialize SQLite database with the required table."""
-         # Check if database file already exists
-         if self.db_path.exists():
-             self._verify_existing_db()
-         else:
-             self._create_new_db()
- 
-     def _verify_existing_db(self) -> None:
-         """Verify and repair an existing database if needed."""
-         try:
-             with sqlite3.connect(self.db_path) as conn:
-                 cursor = conn.cursor()
-                 self._ensure_table_exists(cursor)
-                 self._verify_schema(cursor)
-                 self._ensure_index_exists(cursor)
-                 conn.commit()
-             logger.info(f"Using existing SQLite database at {self.db_path}")
-         except Exception as e:
-             logger.exception(f"Database corruption detected: {e}")
-             raise ValueError(f"Corrupted database at {self.db_path}: {str(e)}")
- 
-     def _create_new_db(self) -> None:
-         """Create a new database with the required schema."""
-         try:
-             with sqlite3.connect(self.db_path) as conn:
-                 cursor = conn.cursor()
-                 self._create_table(cursor)
-                 self._ensure_index_exists(cursor)
-                 conn.commit()
-             logger.info(f"Initialized new SQLite database at {self.db_path}")
-         except Exception as e:
-             logger.exception(f"Failed to initialize SQLite database: {e}")
-             raise
- 
-     def _ensure_table_exists(self, cursor) -> None:
-         """Check if the llm_cache table exists and create it if not."""
-         cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='llm_cache'")
-         if not cursor.fetchone():
-             self._create_table(cursor)
-             logger.info("Created missing llm_cache table")
- 
-     def _create_table(self, cursor) -> None:
-         """Create the llm_cache table with the required schema."""
-         cursor.execute("""
-             CREATE TABLE IF NOT EXISTS llm_cache (
-                 key TEXT PRIMARY KEY,
-                 request_json TEXT,
-                 response_json TEXT,
-                 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-             )
-         """)
- 
-     def _verify_schema(self, cursor) -> None:
-         """Verify that the table schema has all required columns."""
-         cursor.execute("PRAGMA table_info(llm_cache)")
-         columns = {row[1] for row in cursor.fetchall()}
-         required_columns = {"key", "request_json", "response_json", "created_at"}
- 
-         if not required_columns.issubset(columns):
-             missing = required_columns - columns
-             raise ValueError(f"Database schema is corrupted. Missing columns: {missing}")
- 
-     def _ensure_index_exists(self, cursor) -> None:
-         """Create an index on the key column for faster lookups."""
-         cursor.execute("CREATE INDEX IF NOT EXISTS idx_llm_cache_key ON llm_cache (key)")
- 
-     def get(self, key: str) -> Optional[dict[str, Any]]:
-         """Get cached entry by key.
- 
-         Args:
-             key: Cache key to look up
- 
-         Returns:
-             Dict containing the request and response or None if not found
-         """
-         try:
-             with sqlite3.connect(self.db_path) as conn:
-                 conn.row_factory = sqlite3.Row
-                 cursor = conn.cursor()
-                 cursor.execute("SELECT request_json, response_json FROM llm_cache WHERE key = ?", (key,))
-                 result = cursor.fetchone()
- 
-                 if result:
-                     logger.debug(f"Cache hit for key: {key}. Response: {result['response_json']}")
-                     return {
-                         "request": result["request_json"],
-                         "response": result["response_json"],
-                     }
- 
-                 logger.debug(f"Cache miss for key: {key}")
-                 return None
-         except Exception as e:
-             logger.error(f"Error retrieving from cache: {e}")
-             return None
- 
-     def set(self, key: str, request_json: str, response_json: str) -> bool:
-         """Set entry in cache.
- 
-         Args:
-             key: Cache key
-             request_json: JSON string of request parameters
-             response_json: JSON string of response
- 
-         Returns:
-             True if successful, False otherwise
-         """
-         with self.lock:
-             try:
-                 with sqlite3.connect(self.db_path) as conn:
-                     cursor = conn.cursor()
-                     cursor.execute(
-                         "INSERT OR REPLACE INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
-                         (key, request_json, response_json),
-                     )
-                     conn.commit()
-                 logger.debug(f"Saved response to cache with key: {key}, response: {response_json}")
-                 return True
-             except Exception as e:
-                 logger.error(f"Failed to save to SQLite cache: {e}")
-                 return False
- 
-     def get_all_entries(self) -> dict[str, dict[str, Any]]:
-         """Get all cache entries from the database."""
-         cache = {}
-         try:
-             with sqlite3.connect(self.db_path) as conn:
-                 conn.row_factory = sqlite3.Row
-                 cursor = conn.cursor()
-                 cursor.execute("SELECT key, request_json, response_json FROM llm_cache ORDER BY created_at")
- 
-                 for row in cursor.fetchall():
-                     cache[row["key"]] = {
-                         "request": row["request_json"],
-                         "response": row["response_json"],
-                     }
- 
-             logger.debug(f"Retrieved {len(cache)} entries from cache database")
-             return cache
-         except Exception as e:
-             logger.error(f"Error retrieving all cache entries: {e}")
-             return {}
- 
-     def clear(self) -> bool:
-         """Clear all cache entries.
- 
-         Returns:
-             True if successful, False otherwise
-         """
-         with self.lock:
-             try:
-                 with sqlite3.connect(self.db_path) as conn:
-                     cursor = conn.cursor()
-                     cursor.execute("DELETE FROM llm_cache")
-                     conn.commit()
-                 logger.info("Cache cleared")
-                 return True
-             except Exception as e:
-                 logger.error(f"Failed to clear cache: {e}")
-                 return False
- 
-     def get_existing_keys(self) -> set:
-         """Get all existing keys in the database.
- 
-         Returns:
-             Set of keys
-         """
-         existing_keys = set()
-         try:
-             with sqlite3.connect(self.db_path) as conn:
-                 cursor = conn.cursor()
-                 cursor.execute("SELECT key FROM llm_cache")
-                 for row in cursor.fetchall():
-                     existing_keys.add(row[0])
-             return existing_keys
-         except Exception as e:
-             logger.error(f"Error retrieving existing keys: {e}")
-             return set()
- 
-     def bulk_insert(self, items: list, update: bool = False) -> int:
-         """Insert multiple items into the cache.
- 
-         Args:
-             items: List of (key, request_json, response_json) tuples
-             update: Whether to update existing entries
- 
-         Returns:
-             Number of items inserted
-         """
-         count = 0
-         UPDATE_OR_IGNORE = "INSERT OR REPLACE" if update else "INSERT OR IGNORE"
-         with self.lock:
-             try:
-                 with sqlite3.connect(self.db_path) as conn:
-                     cursor = conn.cursor()
-                     cursor.executemany(
-                         f"{UPDATE_OR_IGNORE} INTO llm_cache (key, request_json, response_json) VALUES (?, ?, ?)",
-                         items,
-                     )
-                     count = cursor.rowcount
-                     conn.commit()
-                 return count
-             except Exception as e:
-                 logger.error(f"Error during bulk insert: {e}")
-                 return 0
- 
- 
- class LLMCache:
-     def __init__(
-         self, cache_dir: str = ".", hf_repo: str | None = None, cache_sync_interval: int = 3600, reset: bool = False
-     ):
-         self.cache_dir = Path(cache_dir)
-         self.db_path = self.cache_dir / "llm_cache.db"
-         self.hf_repo_id = hf_repo
-         self.cache_sync_interval = cache_sync_interval
-         self.last_sync_time = time.time()
- 
-         # Create cache directory if it doesn't exist
-         self.cache_dir.mkdir(exist_ok=True, parents=True)
- 
-         # Initialize CacheDB
-         self.db = CacheDB(self.db_path)
-         if reset:
-             self.db.clear()
- 
-         # Try to load from HF dataset if available
-         try:
-             self._load_cache_from_hf()
-         except Exception as e:
-             logger.warning(f"Failed to load cache from HF dataset: {e}")
- 
-     def response_format_to_dict(self, response_format: Any) -> dict[str, Any]:
-         """Convert a response format to a dict."""
-         # If it's a Pydantic model, use its schema
-         if hasattr(response_format, "model_json_schema"):
-             response_format = response_format.model_json_schema()
- 
-         # If it's a Pydantic model, use its dump
-         elif hasattr(response_format, "model_dump"):
-             response_format = response_format.model_dump()
- 
-         if not isinstance(response_format, dict):
-             response_format = {"value": str(response_format)}
- 
-         return response_format
- 
-     def _generate_key(
-         self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None = None
-     ) -> str:
-         """Generate a unique key for caching based on inputs."""
-         response_format_dict = self.response_format_to_dict(response_format)
-         response_format_str = json.dumps(response_format_dict, sort_keys=True)
-         # Include temperature in the key
-         key_content = f"{model}:{system}:{prompt}:{response_format_str}"
-         if temperature is not None:
-             key_content += f":{temperature:.2f}"
-         return hashlib.md5(key_content.encode()).hexdigest()
- 
-     def _create_request_json(
-         self, model: str, system: str, prompt: str, response_format: Any, temperature: float | None
-     ) -> str:
-         """Create JSON string from request parameters."""
-         logger.info(f"Creating request JSON with temperature: {temperature}")
-         request_data = {
-             "model": model,
-             "system": system,
-             "prompt": prompt,
-             "response_format": self.response_format_to_dict(response_format),
-             "temperature": temperature,
-         }
-         return json.dumps(request_data)
- 
-     def _check_request_match(
-         self,
-         cached_request: dict[str, Any],
-         model: str,
-         system: str,
-         prompt: str,
-         response_format: Any,
-         temperature: float | None,
-     ) -> bool:
-         """Check if the cached request matches the new request."""
-         # Check each field and log any mismatches
-         if cached_request["model"] != model:
-             logger.debug(f"Cache mismatch: model - cached: {cached_request['model']}, new: {model}")
-             return False
-         if cached_request["system"] != system:
-             logger.debug(f"Cache mismatch: system - cached: {cached_request['system']}, new: {system}")
-             return False
-         if cached_request["prompt"] != prompt:
-             logger.debug(f"Cache mismatch: prompt - cached: {cached_request['prompt']}, new: {prompt}")
-             return False
-         response_format_dict = self.response_format_to_dict(response_format)
-         if cached_request["response_format"] != response_format_dict:
-             logger.debug(
-                 f"Cache mismatch: response_format - cached: {cached_request['response_format']}, new: {response_format_dict}"
-             )
-             return False
-         if cached_request["temperature"] != temperature:
-             logger.debug(f"Cache mismatch: temperature - cached: {cached_request['temperature']}, new: {temperature}")
-             return False
- 
-         return True
- 
-     def get(
-         self, model: str, system: str, prompt: str, response_format: dict[str, Any], temperature: float | None = None
-     ) -> Optional[dict[str, Any]]:
-         """Get cached response if it exists."""
-         key = self._generate_key(model, system, prompt, response_format, temperature)
-         result = self.db.get(key)
- 
-         if not result:
-             return None
-         request_dict = json.loads(result["request"])
-         if not self._check_request_match(request_dict, model, system, prompt, response_format, temperature):
-             logger.warning(f"Cached request does not match new request for key: {key}")
-             return None
- 
-         return json.loads(result["response"])
- 
-     def set(
-         self,
-         model: str,
-         system: str,
-         prompt: str,
-         response_format: dict[str, Any],
-         temperature: float | None,
-         response: dict[str, Any],
-     ) -> None:
-         """Set response in cache and sync if needed."""
-         key = self._generate_key(model, system, prompt, response_format, temperature)
-         request_json = self._create_request_json(model, system, prompt, response_format, temperature)
-         response_json = json.dumps(response)
- 
-         success = self.db.set(key, request_json, response_json)
- 
-         # Check if we should sync to HF
-         if success and self.hf_repo_id and (time.time() - self.last_sync_time > self.cache_sync_interval):
-             try:
-                 self.sync_to_hf()
-                 self.last_sync_time = time.time()
-             except Exception as e:
-                 logger.error(f"Failed to sync cache to HF dataset: {e}")
- 
-     def _load_cache_from_hf(self) -> None:
-         """Load cache from HF dataset if it exists and merge with local cache."""
-         if not self.hf_repo_id:
-             return
- 
-         try:
-             # Check for new commits before loading the dataset
-             ds_path = (self.cache_dir / "hf_cache").as_posix()
-             dataset = load_dataset_from_hf(self.hf_repo_id, ds_path)["train"]
-             if not dataset:
-                 logger.info("No new items to merge from HF dataset")
-                 return
- 
-             existing_keys = self.db.get_existing_keys()
- 
-             logger.info(f"Found {len(dataset)} items in HF dataset. Existing keys: {len(existing_keys)}")
- 
-             # Prepare batch items for insertion
-             items_to_insert = []
-             for item in dataset:
-                 key = item["key"]
-                 # Only update if not in local cache to prioritize local changes
-                 if key in existing_keys:
-                     continue
-                 # Create request JSON
-                 request_data = {
-                     "model": item["model"],
-                     "system": item["system"],
-                     "prompt": item["prompt"],
-                     "temperature": item["temperature"],
-                     "response_format": None,  # We can't fully reconstruct this
-                 }
- 
-                 items_to_insert.append(
-                     (
-                         key,
-                         json.dumps(request_data),
-                         item["response"],  # This is already a JSON string
-                     )
-                 )
-                 logger.info(
-                     f"Inserting item: {key} with temperature: {item['temperature']} and response: {item['response']}"
-                 )
- 
-             # Bulk insert new items
-             if items_to_insert:
-                 inserted_count = self.db.bulk_insert(items_to_insert)
-                 logger.info(f"Merged {inserted_count} items from HF dataset into SQLite cache")
-             else:
-                 logger.info("No new items to merge from HF dataset")
-         except Exception as e:
-             logger.warning(f"Could not load cache from HF dataset: {e}")
- 
-     def get_all_entries(self) -> dict[str, dict[str, Any]]:
-         """Get all cache entries from the database."""
-         cache = self.db.get_all_entries()
-         entries = {}
-         for key, entry in cache.items():
-             request = json.loads(entry["request"])
-             response = json.loads(entry["response"])
-             entries[key] = {"request": request, "response": response}
-         return entries
- 
-     def sync_to_hf(self) -> None:
-         """Sync cache to HF dataset."""
-         if not self.hf_repo_id:
-             return
- 
-         self._load_cache_from_hf()
- 
-         # Get all entries from the database
-         cache = self.db.get_all_entries()
- 
-         # Convert cache to dataset format
-         entries = []
-         for key, entry in cache.items():
-             request = json.loads(entry["request"])
-             response_str = entry["response"]
-             entries.append(
-                 {
-                     "key": key,
-                     "model": request["model"],
-                     "system": request["system"],
-                     "prompt": request["prompt"],
-                     "response_format": request["response_format"],
-                     "temperature": request["temperature"],
-                     "response": response_str,
-                 }
-             )
- 
-         # Create and push dataset
-         dataset = Dataset.from_list(entries)
-         dataset.push_to_hub(self.hf_repo_id, private=True)
-         logger.info(f"Synced {len(cache)} cached items to HF dataset {self.hf_repo_id}")
- 
-     def clear(self) -> None:
-         """Clear all cache entries."""
-         self.db.clear()
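
A minimal sketch of the cache's read-through behavior (the import path is assumed for the post-move layout; the model name, prompts, and directory are illustrative; response_format may be any Pydantic model, as handled by response_format_to_dict):

    # Hypothetical usage sketch; paths and names are illustrative.
    from pydantic import BaseModel

    from shared.workflows.llmcache import LLMCache  # assumed post-move path

    class Answer(BaseModel):
        text: str

    cache = LLMCache(cache_dir="hf_cache", hf_repo=None)  # no HF syncing in this sketch
    args = ("gpt-4o-mini", "You are terse.", "Capital of France?", Answer, 0.0)
    if cache.get(*args) is None:  # first lookup is a cache miss
        cache.set(*args, response={"text": "Paris"})
    print(cache.get(*args))  # {'text': 'Paris'}

The key is an MD5 over model, system prompt, user prompt, the JSON schema of the response format, and (when set) the temperature, so any change to one of those inputs produces a fresh cache entry.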
 
src/workflows/llms.py DELETED
@@ -1,285 +0,0 @@
- # %%
- 
- import json
- import os
- from typing import Any, Optional
- 
- import cohere
- import numpy as np
- from langchain_anthropic import ChatAnthropic
- from langchain_cohere import ChatCohere
- from langchain_core.language_models import BaseChatModel
- from langchain_openai import ChatOpenAI
- from loguru import logger
- from openai import OpenAI
- from pydantic import BaseModel, Field
- from pydantic._internal._core_utils import CoreSchemaOrField, is_core_schema
- from pydantic.json_schema import GenerateJsonSchema
- from rich import print as rprint
- 
- # Initialize global cache
- from src.envs import CACHE_PATH, LLM_CACHE_REPO
- 
- from .configs import AVAILABLE_MODELS
- from .llmcache import LLMCache
- 
- llm_cache = LLMCache(cache_dir=CACHE_PATH, hf_repo=LLM_CACHE_REPO)
- 
- 
- class CohereSchemaGenerator(GenerateJsonSchema):
-     """Generates JSON schema for Cohere models without default titles."""
- 
-     def field_title_should_be_set(self, schema: CoreSchemaOrField) -> bool:
-         return_value = super().field_title_should_be_set(schema)
-         if return_value and is_core_schema(schema):
-             return False
-         return return_value
- 
- 
- def _openai_is_json_mode_supported(model_name: str) -> bool:
-     if model_name.startswith("gpt-4"):
-         return True
-     if model_name.startswith("gpt-3.5"):
-         return False
-     logger.warning(f"OpenAI model {model_name} is not available in this app, skipping JSON mode, returning False")
-     return False
- 
- 
- class LLMOutput(BaseModel):
-     content: str = Field(description="The content of the response")
-     logprob: Optional[float] = Field(None, description="The log probability of the response")
- 
- 
- def _get_langchain_chat_output(llm: BaseChatModel, system: str, prompt: str) -> dict[str, Any]:
-     output = llm.invoke([("system", system), ("human", prompt)])
-     ai_message = output["raw"]
-     content = {"content": ai_message.content, "tool_calls": ai_message.tool_calls}
-     content_str = json.dumps(content)
-     return {"content": content_str, "output": output["parsed"].model_dump()}
- 
- 
- def _cohere_completion(
-     model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
- ) -> dict[str, Any]:
-     messages = [
-         {"role": "system", "content": system},
-         {"role": "user", "content": prompt},
-     ]
-     client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
-     schema = response_model.model_json_schema(schema_generator=CohereSchemaGenerator)
-     if "title" in schema:
-         del schema["title"]
-     response_format = {
-         "type": "json_object",
-         "schema": schema,
-     }
-     response = client.chat(
-         model=model,
-         messages=messages,
-         response_format=response_format,
-         logprobs=logprobs,
-         temperature=temperature,
-     )
-     output = {}
-     output["content"] = response.message.content[0].text
-     output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
-     if logprobs:
-         output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
-         output["prob"] = np.exp(output["logprob"])
-     return output
- 
- 
- def _openai_langchain_completion(
-     model: str, system: str, prompt: str, response_model, temperature: float | None = None
- ) -> dict[str, Any]:
-     llm = ChatOpenAI(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
-     return _get_langchain_chat_output(llm, system, prompt)
- 
- 
- def _openai_completion(
-     model: str, system: str, prompt: str, response_model, temperature: float | None = None, logprobs: bool = True
- ) -> dict[str, Any]:
-     messages = [
-         {"role": "system", "content": system},
-         {"role": "user", "content": prompt},
-     ]
-     client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
-     response = client.beta.chat.completions.parse(
-         model=model,
-         messages=messages,
-         response_format=response_model,
-         logprobs=logprobs,
-         temperature=temperature,
-     )
-     output = {}
-     output["content"] = response.choices[0].message.content
-     output["output"] = response.choices[0].message.parsed.model_dump()
-     if logprobs:
-         output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
-         output["prob"] = np.exp(output["logprob"])
-     return output
- 
- 
- def _anthropic_completion(
-     model: str, system: str, prompt: str, response_model, temperature: float | None = None
- ) -> dict[str, Any]:
-     llm = ChatAnthropic(model=model, temperature=temperature).with_structured_output(response_model, include_raw=True)
-     return _get_langchain_chat_output(llm, system, prompt)
- 
- 
- def _llm_completion(
-     model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
- ) -> dict[str, Any]:
-     """
-     Generate a completion from an LLM provider with structured output without caching.
- 
-     Args:
-         model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
-         system (str): System prompt/instructions for the model
-         prompt (str): User prompt/input
-         response_format: Pydantic model defining the expected response structure
-         logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
-             Note: Not supported by Anthropic models.
- 
-     Returns:
-         dict: Contains:
-             - output: The structured response matching response_format
-             - logprob: (optional) Sum of log probabilities if logprobs=True
-             - prob: (optional) Exponential of logprob if logprobs=True
- 
-     Raises:
-         ValueError: If logprobs=True with Anthropic models
-     """
-     model_name = AVAILABLE_MODELS[model]["model"]
-     provider = model.split("/")[0]
-     if provider == "Cohere":
-         return _cohere_completion(model_name, system, prompt, response_format, temperature, logprobs)
-     elif provider == "OpenAI":
-         if _openai_is_json_mode_supported(model_name):
-             return _openai_completion(model_name, system, prompt, response_format, temperature, logprobs)
-         elif logprobs:
-             raise ValueError(f"{model} does not support logprobs feature.")
-         else:
-             return _openai_langchain_completion(model_name, system, prompt, response_format, temperature)
-     elif provider == "Anthropic":
-         if logprobs:
-             raise ValueError("Anthropic models do not support logprobs")
-         return _anthropic_completion(model_name, system, prompt, response_format, temperature)
-     else:
-         raise ValueError(f"Provider {provider} not supported")
- 
- 
- def completion(
-     model: str, system: str, prompt: str, response_format, temperature: float | None = None, logprobs: bool = False
- ) -> dict[str, Any]:
-     """
-     Generate a completion from an LLM provider with structured output with caching.
- 
-     Args:
-         model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
-         system (str): System prompt/instructions for the model
-         prompt (str): User prompt/input
-         response_format: Pydantic model defining the expected response structure
-         logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
-             Note: Not supported by Anthropic models.
- 
-     Returns:
-         dict: Contains:
-             - output: The structured response matching response_format
-             - logprob: (optional) Sum of log probabilities if logprobs=True
-             - prob: (optional) Exponential of logprob if logprobs=True
- 
-     Raises:
-         ValueError: If logprobs=True with Anthropic models
-     """
-     if model not in AVAILABLE_MODELS:
-         raise ValueError(f"Model {model} not supported")
-     if logprobs and not AVAILABLE_MODELS[model].get("logprobs", False):
-         logger.warning(f"{model} does not support logprobs feature, setting logprobs to False")
-         logprobs = False
- 
-     # Check cache first
-     cached_response = llm_cache.get(model, system, prompt, response_format, temperature)
-     if cached_response and (not logprobs or cached_response.get("logprob")):
-         logger.debug(f"Cache hit for model {model}")
-         return cached_response
- 
-     logger.debug(f"Cache miss for model {model}, calling API. Logprobs: {logprobs}")
- 
-     # Continue with the original implementation for cache miss
-     response = _llm_completion(model, system, prompt, response_format, temperature, logprobs)
- 
-     # Update cache with the new response
-     llm_cache.set(
-         model,
-         system,
-         prompt,
-         response_format,
-         temperature,
-         response,
-     )
- 
-     return response
- 
- 
- # %%
- if __name__ == "__main__":
-     from tqdm import tqdm
- 
-     class ExplainedAnswer(BaseModel):
-         """
-         The answer to the question and a terse explanation of the answer.
-         """
- 
-         answer: str = Field(description="The short answer to the question")
-         explanation: str = Field(description="5 words terse best explanation of the answer.")
- 
-     models = list(AVAILABLE_MODELS.keys())[:1]  # Just use the first model for testing
-     system = "You are an accurate and concise explainer of scientific concepts."
-     prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
- 
-     llm_cache = LLMCache(cache_dir=".", hf_repo="qanta-challenge/advcal-llm-cache", reset=True)
- 
-     # First call - should be a cache miss
-     logger.info("First call - should be a cache miss")
-     for model in tqdm(models):
-         response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
-         rprint(response)
- 
-     # Second call - should be a cache hit
-     logger.info("Second call - should be a cache hit")
-     for model in tqdm(models):
-         response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
-         rprint(response)
- 
-     # Slightly different prompt - should be a cache miss
-     logger.info("Different prompt - should be a cache miss")
-     prompt2 = "Which planet is closest to the sun? Answer directly."
-     for model in tqdm(models):
-         response = completion(model, system, prompt2, ExplainedAnswer, logprobs=False)
-         rprint(response)
- 
-     # Get cache entries count from SQLite
-     try:
-         cache_entries = llm_cache.get_all_entries()
-         logger.info(f"Cache now has {len(cache_entries)} items")
-     except Exception as e:
-         logger.error(f"Failed to get cache entries: {e}")
- 
-     # Test adding entry with temperature parameter
-     logger.info("Testing with temperature parameter")
-     response = completion(models[0], system, "What is Mars?", ExplainedAnswer, temperature=0.7, logprobs=False)
-     rprint(response)
- 
-     # Demonstrate forced sync to HF if repo is configured
-     if llm_cache.hf_repo_id:
-         logger.info("Forcing sync to HF dataset")
-         try:
-             llm_cache.sync_to_hf()
-             logger.info("Successfully synced to HF dataset")
-         except Exception as e:
-             logger.exception(f"Failed to sync to HF: {e}")
-     else:
-         logger.info("HF repo not configured, skipping sync test")
- 
- # %%
 
src/workflows/qb_agents.py DELETED
@@ -1,232 +0,0 @@
- import time
- from typing import Any, Iterable, TypedDict
- 
- from loguru import logger
- 
- from .executors import WorkflowOutput, execute_workflow
- from .structs import TossupWorkflow, Workflow
- 
- 
- def _get_workflow_response(
-     workflow: Workflow, available_vars: dict[str, Any], logprob_step: bool | str = False
- ) -> tuple[WorkflowOutput, float]:
-     """Get response from executing a complete workflow."""
-     start_time = time.time()
-     workflow_output = execute_workflow(workflow, available_vars, return_full_content=True, logprob_step=logprob_step)
-     response_time = time.time() - start_time
-     return workflow_output, response_time
- 
- 
- class TossupResult(TypedDict):
-     answer: str  # the model's answer
-     confidence: float  # confidence score
-     logprob: float | None  # log probability of the answer
-     buzz: bool  # whether the agent buzzed
-     question_fragment: str  # prefix of the question text so far
-     position: int  # 1-indexed question run index
-     step_contents: list[str]  # string content outputs of each step
-     response_time: float
-     step_outputs: dict[str, Any]
- 
- 
- class BonusResult(TypedDict):
-     answer: str
-     confidence: float
-     explanation: str
-     response_time: float
-     step_contents: list[str]
-     step_outputs: dict[str, Any]
- 
- 
- class QuizBowlTossupAgent:
-     """Agent for handling tossup questions with multiple steps in the workflow."""
- 
-     external_input_variable = "question_text"
-     output_variables = ["answer", "confidence"]
- 
-     def __init__(self, workflow: TossupWorkflow):
-         """Initialize the multi-step tossup agent.
- 
-         Args:
-             workflow: The workflow containing multiple steps
-         """
-         self.workflow = workflow
-         self.output_variables = list(workflow.outputs.keys())
- 
-         # Validate input variables
-         if self.external_input_variable not in workflow.inputs:
-             raise ValueError(f"External input variable {self.external_input_variable} not found in workflow inputs")
- 
-         # Validate output variables
-         for out_var in self.output_variables:
-             if out_var not in workflow.outputs:
-                 raise ValueError(f"Output variable {out_var} not found in workflow outputs")
- 
-     def _single_run(self, question_run: str, position: int) -> TossupResult:
-         """Process a single question run.
-         Args:
-             question_run: The question run to process
-             position: The position of the question run
- 
-         Returns:
-             A TossupResult containing the answer, confidence, logprob, buzz, question fragment, position, step contents, response time, and step outputs
-         """
-         answer_var_step = self.workflow.outputs["answer"].split(".")[0]
-         workflow_output, response_time = _get_workflow_response(
-             self.workflow, {self.external_input_variable: question_run}, logprob_step=answer_var_step
-         )
-         final_outputs = workflow_output["final_outputs"]
-         buzz = self.workflow.buzzer.run(final_outputs["confidence"], logprob=workflow_output["logprob"])
-         result: TossupResult = {
-             "position": position,
-             "answer": final_outputs["answer"],
-             "confidence": final_outputs["confidence"],
-             "logprob": workflow_output["logprob"],
-             "buzz": buzz,
-             "question_fragment": question_run,
-             "step_contents": workflow_output["step_contents"],
-             "step_outputs": workflow_output["intermediate_outputs"],  # Include intermediate step outputs
-             "response_time": response_time,
-         }
-         return result
- 
-     def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[TossupResult]:
-         """Process a tossup question and decide when to buzz based on confidence.
- 
-         Args:
-             question_runs: Progressive reveals of the question text
-             early_stop: Whether to stop after the first buzz
- 
-         Yields:
-             Dict containing:
-             - answer: The model's answer
-             - confidence: Confidence score
-             - buzz: Whether to buzz
-             - question_fragment: Current question text
-             - position: Current position in question
-             - step_contents: String content outputs of each step
-             - response_time: Time taken for response
-             - step_outputs: Outputs from each step
-         """
-         for i, question_text in enumerate(question_runs):
-             # Execute the complete workflow
-             result = self._single_run(question_text, i + 1)
- 
-             yield result
- 
-             # If we've reached the confidence threshold, buzz and stop
-             if early_stop and result["buzz"]:
-                 if i + 1 < len(question_runs):
-                     yield self._single_run(question_runs[-1], len(question_runs))
-                 return
- 
- 
- class QuizBowlBonusAgent:
-     """Agent for handling bonus questions with multiple steps in the workflow."""
- 
-     external_input_variables = ["leadin", "part"]
-     output_variables = ["answer", "confidence", "explanation"]
- 
-     def __init__(self, workflow: Workflow):
-         """Initialize the multi-step bonus agent.
- 
-         Args:
-             workflow: The workflow containing multiple steps
-         """
-         self.workflow = workflow
-         self.output_variables = list(workflow.outputs.keys())
- 
-         # Validate input variables
-         for input_var in self.external_input_variables:
-             if input_var not in workflow.inputs:
-                 raise ValueError(f"External input variable {input_var} not found in workflow inputs")
- 
-         # Validate output variables
-         for out_var in self.output_variables:
-             if out_var not in workflow.outputs:
-                 raise ValueError(f"Output variable {out_var} not found in workflow outputs")
- 
-     def run(self, leadin: str, part: str) -> BonusResult:
-         """Process a bonus part with the given leadin.
- 
-         Args:
-             leadin: The leadin text for the bonus question
-             part: The specific part text to answer
- 
-         Returns:
-             Dict containing:
-             - answer: The model's answer
-             - confidence: Confidence score
-             - explanation: Explanation for the answer
-             - step_contents: String content outputs of each step
-             - response_time: Time taken for response
-             - step_outputs: Outputs from each step
-         """
-         workflow_output, response_time = _get_workflow_response(
-             self.workflow,
-             {
-                 "leadin": leadin,
-                 "part": part,
-             },
-         )
-         final_outputs = workflow_output["final_outputs"]
-         return {
-             "answer": final_outputs["answer"],
-             "confidence": final_outputs["confidence"],
-             "explanation": final_outputs["explanation"],
-             "step_contents": workflow_output["step_contents"],
-             "response_time": response_time,
-             "step_outputs": workflow_output["intermediate_outputs"],  # Include intermediate step outputs
-         }
- 
- 
- # Example usage
- if __name__ == "__main__":
-     # Load the Quizbowl dataset
-     from datasets import load_dataset
- 
-     from workflows.factory import create_simple_qb_bonus_workflow, create_simple_qb_tossup_workflow
- 
-     ds_name = "qanta-challenge/leaderboard_co_set"
-     ds = load_dataset(ds_name, split="train")
- 
-     # Create the agents with multi-step workflows
-     tossup_workflow = create_simple_qb_tossup_workflow()
-     tossup_agent = QuizBowlTossupAgent(workflow=tossup_workflow)  # buzzing is configured on the workflow's Buzzer
- 
-     bonus_workflow = create_simple_qb_bonus_workflow()
-     bonus_agent = QuizBowlBonusAgent(workflow=bonus_workflow)
- 
-     # Example for tossup mode
-     print("\n=== TOSSUP MODE EXAMPLE ===")
-     sample_question = ds[30]
-     print(sample_question["question_runs"][-1])
-     print(sample_question["gold_label"])
-     print()
-     question_runs = sample_question["question_runs"]
- 
-     results = tossup_agent.run(question_runs, early_stop=True)
-     for result in results:
-         print(result["step_contents"])
-         print(f"Guess at position {result['position']}: {result['answer']}")
-         print(f"Confidence: {result['confidence']}")
-         print("Step outputs:", result["step_outputs"])
-         if result["buzz"]:
-             print("Buzzed!\n")
- 
-     # Example for bonus mode
-     print("\n=== BONUS MODE EXAMPLE ===")
-     sample_bonus = ds[31]  # Assuming this is a bonus question
-     leadin = sample_bonus["leadin"]
-     parts = sample_bonus["parts"]
- 
-     print(f"Leadin: {leadin}")
-     for i, part in enumerate(parts):
-         print(f"\nPart {i + 1}: {part['part']}")
-         result = bonus_agent.run(leadin, part["part"])
-         print(f"Answer: {result['answer']}")
-         print(f"Confidence: {result['confidence']}")
-         print(f"Explanation: {result['explanation']}")
-         print(f"Response time: {result['response_time']:.2f}s")
-         print("Step outputs:", result["step_outputs"])
 
src/workflows/structs.py DELETED
@@ -1,370 +0,0 @@
1
- # %%
2
- from copy import deepcopy
3
- from enum import Enum
4
- from typing import Any, Literal, Optional
5
-
6
- import numpy as np
7
- from pydantic import BaseModel, Field, model_validator
8
-
9
- from .configs import AVAILABLE_MODELS
10
-
11
- """
12
- Core data structures for defining workflows and their components.
13
-
14
- This module defines the primary classes used to model workflows, steps, and their
15
- input/output fields. These data structures serve as the foundation for workflow
16
- definition, validation, and execution throughout the workflows package.
17
-
18
- The primary components are:
19
- - InputField: Represents an input to a model step with name and source variable
20
- - OutputField: Represents an output from a model step with name and type
21
- - ModelStep: Represents a single step in a workflow with inputs and outputs
22
- - Workflow: A collection of interconnected steps with defined inputs and outputs
23
-
24
- All classes use Pydantic's BaseModel for validation and serialization support.
25
- """
26
- FieldType = Literal["input", "output"]
27
-
28
-
29
- SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
30
- """Supported field types for input and output fields"""
31
-
32
-
33
- class InputField(BaseModel):
34
- """
35
- Defines an input field for a model step.
36
-
37
- An input field specifies what data a step requires, where it comes from,
38
- and optional pre-processing to apply before use.
39
-
40
- Attributes:
41
- name: The name of the input field within the step's context
42
- description: Human-readable description of the input's purpose
43
- variable: Reference to the source variable (format: "{step_id}.{field_name}" or external input name)
44
- func: Optional function name to transform the input value before use
45
- """
46
-
47
- name: str
48
- description: str
49
- variable: str
50
-
51
- # function to call on the input before passing it to the model
52
- func: str | None = None
53
-
54
- class Config:
55
- frozen = True
56
-
57
-
58
- class OutputField(BaseModel):
59
- """
60
- Defines an output field produced by a model step.
61
-
62
- An output field specifies a value that the step will produce, including
63
- its data type and optional post-processing.
64
-
65
- Attributes:
66
- name: The name of the output field within the step's context
67
- description: Human-readable description of the output's purpose
68
- type: The data type of the output (one of SUPPORTED_TYPES)
69
- func: Optional function name to transform the raw output value
70
- """
71
-
72
- name: str
73
- type: SUPPORTED_TYPES = Field(default="str")
74
- description: str
75
-
76
- # function to call on the output string from the model
77
- func: str | None = None
78
-
79
- class Config:
80
- frozen = True
81
-
82
-
83
- class CallType(str, Enum):
84
- LLM = "llm"
85
- SEARCH = "search"
86
- PYTHON_FUNC = "python_func"
87
-
88
-
89
- class ModelStep(BaseModel):
-     """
-     Represents a single step in a workflow.
-
-     A model step encapsulates the details of a specific operation within a workflow,
-     including what model to use, what inputs it requires, and what outputs it produces.
-
-     Attributes:
-         id: Unique identifier for this step within a workflow
-         name: Human-readable name of the step
-         model: The model to use for this step (e.g., "gpt-4")
-         provider: The provider of the model (e.g., "openai")
-         call_type: The type of operation (e.g., "llm", "search")
-         temperature: Sampling temperature (required when call_type is "llm")
-         system_prompt: Instructions for the model
-         input_fields: List of input fields required by this step
-         output_fields: List of output fields produced by this step
-     """
-
-     id: str
-     name: str
-     model: str
-     provider: str
-     call_type: CallType = CallType.LLM
-
-     # TODO: Validate that this is not None for call_type = llm
-     temperature: Optional[float] = None
-
-     system_prompt: str
-     input_fields: list[InputField]
-     output_fields: list[OutputField]
-
-     class Config:
-         use_enum_values = True
-
-     def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
-         return self.input_fields if field_type == "input" else self.output_fields
-
-     def get_full_model_name(self) -> str:
-         return f"{self.provider}/{self.model}"
-
-     def get_produced_variables(self) -> list[str]:
-         return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]
-
-     def update(self, update: dict[str, Any]) -> "ModelStep":
-         """Returns a new copy with the updated properties."""
-         return self.model_copy(update=update)
-
-     def update_property(self, field: str, value: Any) -> "ModelStep":
-         """Update the `field` key of the model step with `value`."""
-         return self.update({field: value})
-
-     def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
-         """Update a single key of an input or output field at the given index."""
-         if field_type == "input":
-             fields = deepcopy(self.input_fields)
-         elif field_type == "output":
-             fields = deepcopy(self.output_fields)
-         else:
-             raise ValueError(f"Invalid field type: {field_type}")
-
-         # Copy the field list before editing so the original step stays untouched.
-         if index < len(fields):
-             fields[index] = fields[index].model_copy(update={key: value})
-         return self.model_copy(update={f"{field_type}_fields": fields})
-
-     @staticmethod
-     def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
-         if field_type == "input":
-             # Default to an empty reference when no source variable is supplied.
-             return InputField(name="", description="", variable=input_var or "")
-         elif field_type == "output":
-             return OutputField(name="", description="")
-         else:
-             raise ValueError(f"Invalid field type: {field_type}")
-
-     def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
-         """Add a new field to the step.
-
-         Args:
-             field_type: Type of field to add ('input' or 'output').
-             index: Position to insert the new field (-1 to append).
-             input_var: Optional source variable for a new input field.
-
-         Returns:
-             A new ModelStep with the updated fields.
-         """
-         if field_type == "input":
-             fields = deepcopy(self.input_fields)
-             new_field = ModelStep.create_new_field(field_type, input_var)
-         else:
-             fields = deepcopy(self.output_fields)
-             new_field = ModelStep.create_new_field(field_type)
-         if index == -1:
-             fields.append(new_field)
-         else:
-             fields.insert(index + 1, new_field)
-         return self.model_copy(update={f"{field_type}_fields": fields})
-
-     def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
-         """
-         Delete an input or output field from the step.
-
-         Args:
-             field_type: Type of field to delete ('input' or 'output').
-             index: Index of the field to delete. [-1 to delete the last field]
-
-         Returns:
-             A new ModelStep with the updated fields.
-         """
-         fields = deepcopy(self.input_fields if field_type == "input" else self.output_fields)
-         fields.pop(index)
-         return self.model_copy(update={"input_fields": fields} if field_type == "input" else {"output_fields": fields})
-
-
- class Workflow(BaseModel):
-     """
-     Represents a complete workflow composed of interconnected steps.
-
-     A workflow defines a directed acyclic graph of model steps, where outputs
-     from earlier steps can be used as inputs to later steps.
-
-     Attributes:
-         inputs: List of input variables required by the workflow
-         outputs: Mapping from workflow output names to the step variables that produce them
-         steps: Dictionary mapping step IDs to ModelStep instances
-
-     The inputs list and the outputs mapping use the format "{step_id}.{field_name}"
-     to uniquely identify variables within the workflow.
-     """
-
-     # variables of form {node}.{field}
-     inputs: list[str] = Field(default_factory=list)
-
-     # variables of form {node}.{field}
-     outputs: dict[str, str | None] = Field(default_factory=dict)
-     steps: dict[str, ModelStep] = Field(default_factory=dict)
-
-     def model_dump(self, *args, **kwargs):
-         data = super().model_dump(*args, **kwargs)
-         if "steps" in data:
-             data["steps"] = list(data["steps"].values())
-         return data
-
-     @model_validator(mode="before")
-     def dictify_steps(cls, data):
-         if "steps" in data and isinstance(data["steps"], list):
-             steps_dict = {}
-             for step in data["steps"]:
-                 if isinstance(step, ModelStep):
-                     step_id = step.id
-                 else:
-                     step_id = step["id"]
-                 if step_id in steps_dict:
-                     raise ValueError(f"Duplicate step ID: {step_id}")
-                 steps_dict[step_id] = step
-             data["steps"] = steps_dict
-         return data
-
-     def get_step_variables(self, step_id: str) -> list[str]:
-         """Get all output variables produced by a specific step."""
-         step = self.steps[step_id]
-         variables = []
-         for output in step.output_fields:
-             if output.name == "":
-                 continue
-             output_var = f"{step.id}.{output.name}"
-             variables.append(output_var)
-         return variables
-
-     def get_available_variables(self) -> list[str]:
-         """Get all available variables: the workflow inputs plus the outputs of every step."""
-         variables = set(self.inputs)
-         for step in self.steps.values():
-             variables.update(self.get_step_variables(step.id))
-         return list(variables)
-
-     def get_step_model_selections(self) -> dict[str, str]:
-         """Get the selected model for each step, keyed by step ID."""
-         return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}
-
-     def get_output_model_selections(self) -> dict[str, str | None]:
-         """Get the ID of the step that produces each workflow output, keyed by output name."""
-         return {
-             output_var: target_var.split(".")[0] if target_var else None
-             for output_var, target_var in self.outputs.items()
-         }
-
-     # Step update methods
-
-     def add_step(self, step: ModelStep) -> "Workflow":
-         """Add a step to the workflow."""
-         steps = self.steps | {step.id: step}
-         return self.model_copy(update={"steps": steps})
-
-     def remove_step(self, step_id: str) -> "Workflow":
-         """Remove a step from the workflow."""
-         # Build a new mapping rather than mutating self.steps in place.
-         steps = {sid: step for sid, step in self.steps.items() if sid != step_id}
-         workflow = self.model_copy(update={"steps": steps})
-         workflow.refresh_output_variables()
-         return workflow
-
-     def update_step(self, step: ModelStep) -> "Workflow":
-         """Update a step in the workflow."""
-         steps = self.steps | {step.id: step}
-         workflow = self.model_copy(update={"steps": steps})
-         workflow.refresh_output_variables()
-         return workflow
-
-     # Output variables
-     def refresh_output_variables(self) -> "Workflow":
-         """Drop output references to variables that no longer exist in the workflow."""
-         produced_variables = self.get_available_variables()
-         self.outputs = {k: (v if v in produced_variables else None) for k, v in self.outputs.items()}
-         return self
-
-
- class BuzzerMethod(str, Enum):
-     AND = "AND"
-     OR = "OR"
-
-
- class Buzzer(BaseModel):
-     """Configuration for when to buzz in a tossup question."""
-
-     method: BuzzerMethod = BuzzerMethod.AND  # Logic to combine thresholds
-     confidence_threshold: float = Field(default=0.5, ge=0.0, le=1.0)  # Minimum confidence to trigger a buzz
-     prob_threshold: float | None = None  # Optional token probability threshold
-
-     class Config:
-         use_enum_values = True
-         frozen = True
-
-     def update(self, **kwargs) -> "Buzzer":
-         """Return a copy of the buzzer updated with the given kwargs."""
-         return self.model_copy(update=kwargs)
-
-     def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
-         """Run the buzzer logic."""
-         if logprob is not None and prob is not None:
-             raise ValueError("Cannot provide both logprob and prob")
-         if self.prob_threshold is None:
-             return confidence >= self.confidence_threshold
-         if logprob is None and prob is None:
-             raise ValueError("Must provide either logprob or prob if prob_threshold is not None")
-         if prob is None:
-             # Check for None explicitly: `prob or ...` would wrongly fall through for prob == 0.0.
-             prob = float(np.exp(logprob))
-         if self.method == BuzzerMethod.AND:
-             return confidence >= self.confidence_threshold and prob >= self.prob_threshold
-         elif self.method == BuzzerMethod.OR:
-             return confidence >= self.confidence_threshold or prob >= self.prob_threshold
-         else:
-             raise ValueError(f"Invalid buzzer method: {self.method}")
-
-     @model_validator(mode="after")
-     def validate_method_with_log_prob(cls, data):
-         """Validate that the method is 'AND' whenever prob_threshold is None."""
-         if data.prob_threshold is None and data.method != BuzzerMethod.AND:
-             raise ValueError("If prob_threshold is None, method must be 'AND'")
-         return data
-
-
- class TossupWorkflow(Workflow):
-     """Workflow specialized for tossup questions with buzzing capability."""
-
-     buzzer: Buzzer = Field(default_factory=Buzzer)
-
-     def get_answer_model(self, answer_var: str | None = None) -> str | None:
-         answer_var = answer_var or self.outputs["answer"]
-         if answer_var is None:
-             return None
-         step_id = answer_var.split(".")[0]
-         return self.steps[step_id].get_full_model_name()
-
-     def is_token_probs_supported(self, answer_var: str | None = None) -> bool:
-         model_name = self.get_answer_model(answer_var)
-         if model_name is None:
-             return True
-         return AVAILABLE_MODELS[model_name].get("logprobs", False)
-
-     def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
-         """Update the buzzer."""
-         return self.model_copy(update={"buzzer": buzzer})
-
-     def refresh_buzzer(self) -> "TossupWorkflow":
-         if not self.is_token_probs_supported():
-             return self.update_buzzer(self.buzzer.update(prob_threshold=None, method="AND"))
-         return self
 
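For orientation, here is a minimal sketch (not part of this commit) of how these structs compose. The model name, provider, prompt, and field names are placeholder values, and the import path assumes the new `shared.workflows` submodule location:

```python
from shared.workflows.structs import (
    Buzzer,
    InputField,
    ModelStep,
    OutputField,
    TossupWorkflow,
)

# A single LLM step that reads the external input "question_text"
# and produces an answer plus a confidence score.
step = ModelStep(
    id="answerer",
    name="Answerer",
    model="gpt-4o-mini",       # placeholder model
    provider="OpenAI",          # placeholder provider
    call_type="llm",
    temperature=0.2,
    system_prompt="Answer the question and rate your confidence.",
    input_fields=[
        InputField(name="question", description="Question text so far", variable="question_text"),
    ],
    output_fields=[
        OutputField(name="answer", description="Best-guess answer", type="str"),
        OutputField(name="confidence", description="Confidence in [0, 1]", type="float"),
    ],
)

workflow = TossupWorkflow(
    inputs=["question_text"],
    outputs={"answer": "answerer.answer", "confidence": "answerer.confidence"},
    steps=[step],  # the before-validator converts the list into an id-keyed dict
    buzzer=Buzzer(method="AND", confidence_threshold=0.8, prob_threshold=0.5),
)

print(workflow.get_answer_model())         # "OpenAI/gpt-4o-mini"
print(workflow.buzzer.run(0.9, prob=0.6))  # True: both thresholds are met
```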
src/workflows/utils.py DELETED
@@ -1,195 +0,0 @@
- from collections import deque
- from typing import Any, Iterable
-
- from .errors import CyclicDependencyError, UnknownVariableError, WorkflowError
- from .structs import Workflow
-
- """
- Utilities for workflow dependency management and execution order determination.
-
- This module provides functions for analyzing workflows, determining dependencies between steps,
- and calculating the correct execution order to ensure all dependencies are satisfied.
- Key functionality includes:
-
- - Variable to step mapping: Identifying which step produces each variable
- - Dependency graph creation: Building a graph representing dependencies between steps
- - Topological sorting: Determining a valid execution order based on dependencies
- - Cycle detection: Identifying cyclic dependencies that would prevent execution
-
- These utilities form the foundation for workflow validation and execution in the
- workflow_executor module.
- """
-
-
- def _create_variable_step_mapping(workflow: Workflow) -> dict[str, str]:
-     """
-     Creates a mapping from produced variable names to the model step that produces them.
-
-     Args:
-         workflow (Workflow): The workflow containing steps and their input/output fields.
-
-     Returns:
-         dict[str, str]: A dictionary where keys are variable names (formatted as "{step_id}.{output name}")
-             and values are the step IDs that produce them.
-
-     Raises:
-         WorkflowError: If there are duplicate step IDs or if a variable is produced by multiple steps.
-
-     Example:
-         For a workflow with steps "extract" and "summarize" each producing outputs:
-         >>> _create_variable_step_mapping(workflow)
-         {'extract.keywords': 'extract', 'summarize.summary': 'summarize'}
-     """
-     variable_step_map: dict[str, str] = {}  # variable name -> step id
-     for step_id, step in workflow.steps.items():
-         for output in step.output_fields:
-             var_name = f"{step_id}.{output.name}"
-             if var_name in variable_step_map:
-                 raise WorkflowError(f"Variable '{output.name}' has duplicate entry in step {step_id}")
-             variable_step_map[var_name] = step_id
-     return variable_step_map
-
-
- def create_dependency_graph(workflow: Workflow, input_values: dict[str, Any]) -> dict[str, set[str]]:
-     """
-     Creates a dependency graph from a workflow.
-
-     This function analyzes the workflow and determines which steps depend on others
-     based on their input/output relationships. A step depends on another if it requires
-     a variable that is produced by the other step. External inputs provided through
-     input_values don't create dependencies.
-
-     Args:
-         workflow (Workflow): The workflow containing steps and their input/output fields.
-         input_values (dict[str, Any]): A dictionary of external input values provided to the workflow.
-
-     Returns:
-         dict[str, set[str]]: A dictionary where keys are step IDs and values are sets of step IDs
-             that the key step depends on.
-
-     Raises:
-         UnknownVariableError: If an input field references a variable that is neither provided
-             externally nor produced by any step.
-
-     Example:
-         For a workflow where step "classify" depends on output from "extract":
-         >>> create_dependency_graph(workflow, {})
-         {'extract': set(), 'classify': {'extract'}}
-
-         With external input provided for the "text" variable:
-         >>> create_dependency_graph(workflow, {'text': 'Sample text'})
-         {'extract': set(), 'classify': {'extract'}}
-     """
-     produced_by = _create_variable_step_mapping(workflow)
-     dependencies: dict[str, set[str]] = {step_id: set() for step_id in workflow.steps}
-     for step_id, step in workflow.steps.items():
-         for input_field in step.input_fields:
-             var = input_field.variable
-             # If the variable was provided externally, then no dependency is needed.
-             if var in input_values:
-                 continue
-             # Otherwise, check if the variable is produced by a step.
-             if var in produced_by:
-                 producer_step_id = produced_by[var]
-                 if producer_step_id != step_id:  # Avoid self-dependency
-                     dependencies[step_id].add(producer_step_id)
-             else:
-                 raise UnknownVariableError(f"Variable '{var}' is not provided externally nor produced by any step")
-     return dependencies
-
-
- def detect_cycles(dep_graph: dict[str, Iterable[str]]) -> str | None:
-     """Detects cycles in the dependency graph.
-
-     Args:
-         dep_graph: A dictionary mapping each node ID to the node IDs it depends on
-     Returns:
-         The ID of the first node found to be part of a cycle, or None if no cycle is found
-     """
-     # Check for cycles in step dependencies
-     visited = set()
-     path = set()
-
-     def has_cycle(node: str) -> bool:
-         if node in path:
-             return True
-         if node in visited:
-             return False
-
-         visited.add(node)
-         path.add(node)
-
-         for neighbor in dep_graph.get(node, set()):
-             if has_cycle(neighbor):
-                 return True
-
-         path.remove(node)
-         return False
-
-     # Check each step for cycles
-     for node_id in dep_graph:
-         if has_cycle(node_id):
-             return node_id
-     return None
-
-
- def topological_sort(dependencies: dict[str, set[str]]) -> list[str]:
-     """
-     Performs a topological sort on a dependency graph and detects cycles using Kahn's algorithm.
-
-     A topological sort orders the steps such that for every dependency from step A to step B,
-     step A comes before step B in the ordering. This ensures that all dependencies are satisfied
-     when executing steps in the returned order.
-
-     Args:
-         dependencies (dict[str, set[str]]): A dictionary where each key is a node identifier and
-             each value is a set of nodes that the key node depends on.
-
-     Returns:
-         list[str]: A list representing the nodes in topological order if no cycle is detected.
-
-     Raises:
-         CyclicDependencyError: If a cycle is detected in the graph.
-
-     Example:
-         >>> topological_sort({'A': set(), 'B': {'A'}, 'C': {'B'}})
-         ['A', 'B', 'C']
-
-         >>> topological_sort({'A': {'B'}, 'B': {'A'}})  # Cyclic dependency
-         CyclicDependencyError
-
-     Algorithm:
-         This implementation uses Kahn's algorithm:
-         1. Calculate in-degree for all nodes (number of dependencies)
-         2. Start with nodes having 0 in-degree (no dependencies)
-         3. Process each node by removing its outgoing edges
-         4. Add newly dependency-free nodes to the processing queue
-         5. If not all nodes are processed, a cycle exists
-     """
-
-     nodes = list(dependencies.keys())
-     dependents: dict[str, list[str]] = {node: [] for node in nodes}
-     in_degree: dict[str, int] = dict.fromkeys(nodes, 0)
-
-     # Calculate in-degrees and build dependents list
-     for node, deps in dependencies.items():
-         in_degree[node] = len(deps)
-         for dep in deps:
-             dependents[dep].append(node)
-
-     # Initialize queue with nodes having zero in-degree
-     queue = deque([node for node, deg in in_degree.items() if deg == 0])
-     execution_order: list[str] = []
-
-     # Process nodes in topological order
-     while queue:
-         current = queue.popleft()
-         execution_order.append(current)
-         for dep in dependents[current]:
-             in_degree[dep] -= 1
-             if in_degree[dep] == 0:
-                 queue.append(dep)
-
-     # If the execution order does not include all nodes, a cycle exists
-     if len(execution_order) != len(nodes):
-         raise CyclicDependencyError()
-     return execution_order
 
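For a quick sense of the dependency utilities, a small sketch (not part of this commit): since `detect_cycles` and `topological_sort` operate on plain dicts, they can be exercised without building a full workflow. The import path assumes the new `shared.workflows` submodule location:

```python
from shared.workflows.utils import detect_cycles, topological_sort

# A toy dependency graph: 'classify' needs 'extract', 'report' needs both.
deps = {
    "extract": set(),
    "classify": {"extract"},
    "report": {"extract", "classify"},
}

print(detect_cycles(deps))     # None: the graph is acyclic
print(topological_sort(deps))  # ['extract', 'classify', 'report']

# Introducing a back-edge creates a cycle, which detect_cycles reports
# and topological_sort rejects with CyclicDependencyError.
deps["extract"].add("report")
print(detect_cycles(deps))     # 'extract' (or another node on the cycle)
```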
src/workflows/validators.py DELETED
@@ -1,615 +0,0 @@
- import keyword
- import re
- from dataclasses import dataclass
- from enum import Enum
- from typing import Optional
-
- from loguru import logger
-
- from .structs import CallType, InputField, ModelStep, OutputField, Workflow
- from .utils import detect_cycles
-
- SUPPORTED_TYPES = {"str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"}
-
- # Constants for validation
- MAX_FIELD_NAME_LENGTH = 50
- MAX_DESCRIPTION_LENGTH = 200
- MAX_SYSTEM_PROMPT_LENGTH = 4000
- MAX_TEMPERATURE = 10.0
-
-
- class ValidationErrorType(Enum):
-     """Types of validation errors that can occur"""
-
-     INPUTS = "inputs"
-     OUTPUTS = "outputs"
-     STEP = "step"
-     DAG = "dag"
-     VARIABLE = "variable"
-     TYPE = "type"
-     GENERAL = "general"
-     NAMING = "naming"
-     LENGTH = "length"
-     RANGE = "range"
-
-
- @dataclass
- class ValidationError:
-     """Represents a validation error with type and message"""
-
-     error_type: ValidationErrorType
-     message: str
-     step_id: Optional[str] = None
-     field_name: Optional[str] = None
-
-     def __str__(self):
-         subject = ""
-         if self.step_id:
-             subject = f"Model step '{self.step_id}'"
-         if self.field_name:
-             if self.step_id:
-                 subject = f"Field '{self.step_id}.{self.field_name}'"
-             else:
-                 subject = f"Field '{self.field_name}'"
-         return f"{self.error_type.value}: {subject} - {self.message}"
-
-
- class WorkflowValidationError(ValueError):
-     """Base class for workflow validation errors"""
-
-     def __init__(self, errors: list[ValidationError]):
-         self.errors = errors
-         super().__init__(f"Workflow validation failed with {len(errors)} errors")
-
-
- def _parse_variable_reference(var: str) -> tuple[Optional[str], str]:
-     """Extracts step_id and field_name from a variable reference"""
-     parts = var.split(".")
-     if len(parts) == 1:
-         return None, parts[0]
-     return parts[0], parts[1]
-
-
- def _get_step_dependencies(step: ModelStep) -> set[str]:
-     """Gets the set of step IDs that this step depends on"""
-     deps = set()
-     for field in step.input_fields:
-         step_id, _ = _parse_variable_reference(field.variable)
-         if step_id:
-             deps.add(step_id)
-     return deps
-
-
- def create_step_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
-     """Creates a dependency graph of steps"""
-     dep_graph: dict[str, set[str]] = {}
-     for step_id, step in workflow.steps.items():
-         dep_graph[step_id] = _get_step_dependencies(step)
-     return dep_graph
-
-
- class WorkflowValidator:
-     """Validates workflows for correctness and consistency"""
-
-     def __init__(
-         self,
-         min_temperature: float = 0,
-         max_temperature: float = MAX_TEMPERATURE,
-         max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
-         max_description_length: int = MAX_DESCRIPTION_LENGTH,
-         max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
-         allowed_model_names: Optional[list[str]] = None,
-         required_input_vars: Optional[list[str]] = None,
-         required_output_vars: Optional[list[str]] = None,
-     ):
-         self.errors: list[ValidationError] = []
-         self.workflow: Optional[Workflow] = None
-         self.min_temperature = min_temperature
-         self.max_temperature = max_temperature
-         self.max_field_name_length = max_field_name_length
-         self.max_description_length = max_description_length
-         self.max_system_prompt_length = max_system_prompt_length
-         self.required_input_vars = required_input_vars
-         self.required_output_vars = required_output_vars
-         self.allowed_model_names = set(allowed_model_names) if allowed_model_names else None
-
-     def validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         validated = self._validate(workflow, allow_empty)
-         if not validated:
-             raise WorkflowValidationError(self.errors)
-         return True
-
-     def _validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         """Main validation entry point
-
-         Args:
-             workflow: The workflow to validate.
-             allow_empty: If True, an empty workflow is allowed. This flag is used to validate
-                 intermediate states while the user edits the workflow.
-         """
-         self.errors = []
-         self.workflow = workflow
-
-         # Basic workflow validation
-         if not self._validate_workflow_basic(workflow, allow_empty):
-             return False
-
-         # If it's a single-step workflow, use simple validation
-         if len(workflow.steps) == 1:
-             return self.validate_simple_workflow(workflow, allow_empty)
-
-         # Otherwise use complex validation
-         return self.validate_complex_workflow(workflow, allow_empty)
-
-     def _validate_required_inputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         """Validates that the workflow has the required inputs"""
-         required_input_vars = self.required_input_vars or []
-         input_vars = set(workflow.inputs)
-         for req_var in required_input_vars:
-             if req_var in input_vars:
-                 continue
-             self.errors.append(
-                 ValidationError(ValidationErrorType.INPUTS, f"Workflow must have '{req_var}' as an input")
-             )
-             return False
-
-         for input_var in input_vars:
-             if not self._is_valid_external_input(input_var):
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
-                 )
-                 return False
-         return True
-
-     def _validate_required_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         """Validates that the workflow has the required outputs"""
-         required_output_vars = self.required_output_vars or []
-         output_vars = set(workflow.outputs)
-         for req_var in required_output_vars:
-             if req_var in output_vars:
-                 continue
-             self.errors.append(
-                 ValidationError(ValidationErrorType.OUTPUTS, f"Workflow must produce '{req_var}' as an output")
-             )
-             return False
-
-         # Validate output variables
-         for output_name, output_var in workflow.outputs.items():
-             logger.debug(f"Output name: {output_name}, Output var: {output_var}")
-             if not output_var:
-                 if allow_empty:
-                     continue
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
-                 )
-                 return False
-
-             # Check if the output variable references a valid step output
-             if not self._is_valid_variable_reference(output_var):
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
-                 )
-                 return False
-
-             # Verify the output field exists in the referenced step
-             step_id, field_name = _parse_variable_reference(output_var)
-             logger.debug(f"Step ID: {step_id}, Field name: {field_name}, Workflow steps: {workflow.steps.keys()}")
-             if step_id not in workflow.steps:
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.VARIABLE, f"Referenced model step '{step_id}' not found")
-                 )
-                 return False
-
-             ref_step = workflow.steps[step_id]
-             if not any(field.name == field_name for field in ref_step.output_fields):
-                 self.errors.append(
-                     ValidationError(
-                         ValidationErrorType.VARIABLE,
-                         f"Output field '{field_name}' not found in model step '{step_id}'",
-                         step_id,
-                         field_name,
-                     )
-                 )
-                 return False
-         return True
-
-     def validate_input_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         """Validates the input and output variables"""
-         self._validate_required_inputs(workflow, allow_empty)
-         self._validate_required_outputs(workflow, allow_empty)
-
-         # Check for at least one input
-         if not workflow.inputs:
-             self.errors.append(
-                 ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
-             )
-
-         # Check for at least one output
-         if not workflow.outputs:
-             self.errors.append(
-                 ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
-             )
-
-         return len(self.errors) == 0
-
-     def validate_simple_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         """Validates a single-step workflow"""
-         if not self.workflow:
-             return False
-
-         # Get the single step
-         step = next(iter(workflow.steps.values()))
-
-         # Validate the step itself
-         if not self._validate_step(step, allow_empty):
-             return False
-
-         return True
-
-     def validate_complex_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         """Validates a multi-step workflow"""
-         if not self.workflow:
-             return False
-
-         # Validate each step
-         for step in workflow.steps.values():
-             if not self._validate_step(step, allow_empty):
-                 return False
-
-         dep_graph = create_step_dep_graph(workflow)
-         if cycle_step_id := detect_cycles(dep_graph):
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.DAG, f"Circular dependency detected involving step: {cycle_step_id}"
-                 )
-             )
-             return False
-
-         # Check for orphaned steps (steps whose outputs feed neither another step nor a workflow output)
-         used_steps = set()
-         for deps in dep_graph.values():
-             used_steps.update(deps)
-         for step_id in workflow.steps:
-             if step_id not in used_steps and not any(
-                 output_var and _parse_variable_reference(output_var)[0] == step_id
-                 for output_var in workflow.outputs.values()
-             ):
-                 self.errors.append(ValidationError(ValidationErrorType.DAG, f"Orphaned step detected: {step_id}"))
-                 return False
-
-         # Validate variable dependencies
-         if not self._validate_variable_dependencies(workflow):
-             return False
-
-         return True
-
-     def _validate_workflow_basic(self, workflow: Workflow, allow_empty: bool = False) -> bool:
-         """Validates basic workflow properties"""
-         # Check the workflow inputs and outputs
-         if not self.validate_input_outputs(workflow, allow_empty):
-             return False
-
-         # Check for empty workflow
-         if not workflow.steps:
-             if allow_empty:
-                 return True
-             self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
-             return False
-
-         # Check for step ID consistency
-         for step_id, step in workflow.steps.items():
-             if step_id != step.id:
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.STEP, f"Step ID mismatch: {step_id} != {step.id}", step_id)
-                 )
-                 return False
-         return True
-
-     def _validate_step(self, step: ModelStep, allow_empty: bool = False) -> bool:
-         """Validates a single step"""
-         # Validate required fields
-         model_name = step.get_full_model_name()
-
-         if model_name == "/" and not allow_empty:
-             self.errors.append(
-                 ValidationError(ValidationErrorType.STEP, "Model name and provider cannot be empty", step.id)
-             )
-             return False
-
-         # Check if the model name is allowed
-         if self.allowed_model_names and model_name not in self.allowed_model_names:
-             self.errors.append(
-                 ValidationError(ValidationErrorType.STEP, f"Model name '{model_name}' is not allowed", step.id)
-             )
-             return False
-
-         if not step.id or not step.call_type:
-             self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
-             return False
-
-         # Validate step ID and name
-         if not self._is_valid_identifier(step.id):
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.NAMING,
-                     f"Invalid step ID format: {step.id}. Must be a valid identifier.",
-                     step.id,
-                 )
-             )
-             return False
-
-         # Validate temperature for LLM call type
-         if step.call_type == CallType.LLM:
-             if step.temperature is None:
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.STEP, "LLM step must specify temperature", step.id)
-                 )
-                 return False
-
-             if not self.min_temperature <= step.temperature <= self.max_temperature:
-                 self.errors.append(
-                     ValidationError(
-                         ValidationErrorType.RANGE,
-                         f"Temperature must be between {self.min_temperature} and {self.max_temperature}",
-                         step.id,
-                     )
-                 )
-                 return False
-
-         # Validate system prompt for LLM call type
-         if step.call_type == CallType.LLM:
-             if not step.system_prompt:
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.STEP, "LLM step must specify system prompt", step.id)
-                 )
-                 return False
-
-             if len(step.system_prompt) > self.max_system_prompt_length:
-                 self.errors.append(
-                     ValidationError(
-                         ValidationErrorType.LENGTH,
-                         f"System prompt exceeds maximum length of {self.max_system_prompt_length} characters",
-                         step.id,
-                     )
-                 )
-                 return False
-
-         # Validate input fields
-         input_names = set()
-         for field in step.input_fields:
-             if not self._validate_input_field(field, allow_empty):
-                 return False
-             if field.name in input_names:
-                 self.errors.append(
-                     ValidationError(
-                         ValidationErrorType.STEP, f"Duplicate input field name: {field.name}", step.id, field.name
-                     )
-                 )
-                 return False
-             input_names.add(field.name)
-
-         # Validate output fields
-         output_names = set()
-         for field in step.output_fields:
-             if not self._validate_output_field(field, allow_empty):
-                 return False
-             if field.name in output_names:
-                 self.errors.append(
-                     ValidationError(
-                         ValidationErrorType.STEP, f"Duplicate output field name: {field.name}", step.id, field.name
-                     )
-                 )
-                 return False
-             output_names.add(field.name)
-
-         return True
-
-     def _validate_input_field(self, field: InputField, allow_empty: bool = False) -> bool:
-         """Validates an input field"""
-         # Validate required fields
-         if not field.name or not field.description or not field.variable:
-             self.errors.append(
-                 ValidationError(ValidationErrorType.STEP, "Input field missing required fields", field_name=field.name)
-             )
-             return False
-
-         # Validate field name
-         if not self._is_valid_identifier(field.name, allow_empty):
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.NAMING,
-                     f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
-                     field_name=field.name,
-                 )
-             )
-             return False
-
-         # Validate field name length
-         if len(field.name) > self.max_field_name_length:
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.LENGTH,
-                     f"Field name exceeds maximum length of {self.max_field_name_length} characters",
-                     field_name=field.name,
-                 )
-             )
-             return False
-
-         # Validate description length
-         if len(field.description) > self.max_description_length:
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.LENGTH,
-                     f"Description exceeds maximum length of {self.max_description_length} characters",
-                     field_name=field.name,
-                 )
-             )
-             return False
-
-         # Validate variable reference
-         if not self._is_valid_variable_reference(field.variable):
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.VARIABLE,
-                     f"Invalid variable reference: {field.variable}",
-                     field_name=field.name,
-                 )
-             )
-             return False
-
-         return True
-
-     def _validate_output_field(self, field: OutputField, allow_empty: bool = False) -> bool:
-         """Validates an output field"""
-         # Validate required fields
-         if not field.name or not field.description:
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.STEP, "Output field missing required fields", field_name=field.name
-                 )
-             )
-             return False
-
-         # Validate field name
-         if not self._is_valid_identifier(field.name, allow_empty):
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.NAMING,
-                     f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
-                     field_name=field.name,
-                 )
-             )
-             return False
-
-         # Validate field name length
-         if len(field.name) > self.max_field_name_length:
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.LENGTH,
-                     f"Field name exceeds maximum length of {self.max_field_name_length} characters",
-                     field_name=field.name,
-                 )
-             )
-             return False
-
-         # Validate description length
-         if len(field.description) > self.max_description_length:
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.LENGTH,
-                     f"Description exceeds maximum length of {self.max_description_length} characters",
-                     field_name=field.name,
-                 )
-             )
-             return False
-
-         # Validate type
-         if field.type not in SUPPORTED_TYPES:
-             self.errors.append(
-                 ValidationError(
-                     ValidationErrorType.TYPE, f"Unsupported output type: {field.type}", field_name=field.name
-                 )
-             )
-             return False
-
-         return True
-
-     def _validate_simple_workflow_variables(self, workflow: Workflow) -> bool:
-         """Validates variables in a simple workflow"""
-         step = next(iter(workflow.steps.values()))
-
-         # Validate input variables
-         for input_var in workflow.inputs:
-             if not self._is_valid_external_input(input_var):
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
-                 )
-                 return False
-
-         # Validate output variables
-         for output_name, output_var in workflow.outputs.items():
-             if output_var and not self._is_valid_variable_reference(output_var):
-                 self.errors.append(
-                     ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
-                 )
-                 return False
-
-         return True
-
-     def _validate_variable_dependencies(self, workflow: Workflow) -> bool:
-         """Validates variable dependencies between steps"""
-
-         def create_var_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
-             var_graph: dict[str, set[str]] = {}
-             for step_id, step in workflow.steps.items():
-                 for field in step.input_fields:
-                     if field.variable not in var_graph:
-                         var_graph[field.variable] = set()
-                     # Add dependency from input variable to step's outputs
-                     for output in step.output_fields:
-                         var_graph[field.variable].add(f"{step_id}.{output.name}")
-             return var_graph
-
-         # Check for cycles in variable dependencies
-         var_graph = create_var_dep_graph(workflow)
-         if cycle_var := detect_cycles(var_graph):
-             self.errors.append(
-                 ValidationError(ValidationErrorType.VARIABLE, f"Circular variable dependency detected: {cycle_var}")
-             )
-             return False
-
-         # Validate external input existence
-         external_inputs = set(workflow.inputs)
-         for step in workflow.steps.values():
-             for field in step.input_fields:
-                 step_id, field_name = _parse_variable_reference(field.variable)
-                 if not step_id and field_name not in external_inputs:
-                     self.errors.append(
-                         ValidationError(
-                             ValidationErrorType.VARIABLE,
-                             f"External input '{field_name}' not found in workflow inputs",
-                             field_name=field_name,
-                         )
-                     )
-                     return False
-
-         return True
-
-     def _is_valid_variable_reference(self, var: str | None, allow_empty: bool = True) -> bool:
-         """Validates that a variable reference is properly formatted"""
-         if not self.workflow:
-             return False
-         if var is None:
-             return allow_empty
-         parts = var.split(".")
-         if len(parts) == 1:
-             return True  # External input
-         if len(parts) != 2:
-             return False
-         step_id, field_name = parts
-         return step_id in self.workflow.steps and any(
-             field.name == field_name for field in self.workflow.steps[step_id].output_fields
-         )
-
-     def _is_valid_external_input(self, var: str) -> bool:
-         """Validates that a variable is a valid external input"""
-         if not var:
-             return False
-         if not self._is_valid_identifier(var):
-             return False
-         if keyword.iskeyword(var):
-             return False
-         if "." in var:  # External inputs should not contain dots
-             return False
-         return True
-
-     def _is_valid_identifier(self, name: str, allow_empty: bool = False) -> bool:
-         """Validates that a string is a valid Python identifier"""
-         if name and name.strip():
-             return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
-         return allow_empty
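To close the loop, a small sketch (not part of this commit) of how the validator is typically driven; the required variable names are illustrative, `workflow` is the object built in the structs sketch above, and the import path assumes the new `shared.workflows` submodule location:

```python
from shared.workflows.validators import WorkflowValidationError, WorkflowValidator

validator = WorkflowValidator(
    required_input_vars=["question_text"],  # illustrative requirements
    required_output_vars=["answer"],
)

try:
    # validate() returns True on success and raises on failure,
    # carrying the accumulated ValidationError list.
    validator.validate(workflow)
    print("workflow is valid")
except WorkflowValidationError as e:
    for error in e.errors:
        print(error)  # ValidationError.__str__ renders "type: subject - message"
```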