Maharshi Gor
Enhance model provider detection and add a repository management script. Add support for a multi-step agent.
973519b
import time | |
from typing import Any, Iterable | |
# from litellm import completion | |
from llms import completion | |
from workflows.executors import execute_model_step, execute_workflow | |
from workflows.structs import ModelStep, Workflow | |
def _get_agent_response(self, prompt: str, system_prompt: str) -> tuple[Any, float]:
    """Query the LLM with a system + user message pair and time the call.

    Args:
        prompt: User message content.
        system_prompt: System message content.

    Returns:
        Tuple of (raw completion response, elapsed wall-clock seconds).
        NOTE: the original annotation claimed ``-> dict`` but the function
        has always returned this two-tuple; the annotation is now accurate.
    """
    messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
    start_time = time.time()
    response = completion(
        model=self.model,
        messages=messages,
        temperature=self.temperature,
        max_tokens=150,  # Limit token usage for faster responses
    )
    response_time = time.time() - start_time
    return response, response_time
def _get_model_step_response(
    model_step: ModelStep, available_vars: dict[str, Any]
) -> tuple[dict[str, Any], str, float]:
    """Execute a single model step and report its outputs, raw content, and latency."""
    t0 = time.time()
    outputs, full_content = execute_model_step(model_step, available_vars, return_full_content=True)
    elapsed = time.time() - t0
    return outputs, full_content, elapsed
class SimpleTossupAgent:
    """Single-step agent that answers tossup questions and decides when to buzz."""

    # Input variable the workflow must accept and outputs it must produce.
    external_input_variable = "question_text"
    output_variables = ["answer", "confidence"]

    def __init__(self, workflow: Workflow, buzz_threshold: float):
        """Validate the workflow and cache its single model step.

        Args:
            workflow: Workflow containing exactly one model step.
            buzz_threshold: Minimum confidence required to buzz.

        Raises:
            ValueError: If the required input variable or any required output
                variable is missing from the workflow.
        """
        steps = list(workflow.steps.values())
        assert len(steps) == 1, "Only one step is allowed in a simple workflow"
        self.model_step = steps[0]
        self.buzz_threshold = buzz_threshold
        # BUG FIX: capture the class-level required outputs BEFORE overriding the
        # attribute. The original code reassigned self.output_variables from
        # workflow.outputs and then checked those same keys against
        # workflow.outputs — a tautology that could never raise.
        required_outputs = self.output_variables
        self.output_variables = list(workflow.outputs.keys())

        if self.external_input_variable not in workflow.inputs:
            raise ValueError(f"External input variable {self.external_input_variable} not found in model step inputs")
        for out_var in required_outputs:
            if out_var not in workflow.outputs:
                raise ValueError(f"Output variable {out_var} not found in the workflow outputs")

    def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[dict]:
        """
        Process a tossup question and decide when to buzz based on confidence.

        Args:
            question_runs: Progressive reveals of the question text
            early_stop: Whether to stop after the first buzz

        Yields:
            Dict with answer, confidence, and whether to buzz
        """
        for i, question_text in enumerate(question_runs):
            response, content, response_time = _get_model_step_response(
                self.model_step, {self.external_input_variable: question_text}
            )
            buzz = response["confidence"] >= self.buzz_threshold
            result = {
                "answer": response["answer"],
                "confidence": response["confidence"],
                "buzz": buzz,
                "question_fragment": question_text,
                "position": i + 1,  # 1-indexed reveal position
                "full_response": content,
                "response_time": response_time,
            }
            yield result

            # If we've reached the confidence threshold, buzz and stop
            if early_stop and buzz:
                return
class SimpleBonusAgent:
    """Single-step agent that answers one bonus part given its leadin."""

    # Input variables the workflow must accept and outputs it must produce.
    external_input_variables = ["leadin", "part"]
    output_variables = ["answer", "confidence", "explanation"]

    def __init__(self, workflow: Workflow):
        """Validate the workflow and cache its single model step.

        Args:
            workflow: Workflow containing exactly one model step.

        Raises:
            ValueError: If a required input or output variable is missing
                from the workflow.
        """
        steps = list(workflow.steps.values())
        assert len(steps) == 1, "Only one step is allowed in a simple workflow"
        self.model_step = steps[0]
        # BUG FIX: capture the class-level required outputs BEFORE overriding the
        # attribute; the original validated workflow.outputs against itself,
        # so the check could never fail.
        required_outputs = self.output_variables
        self.output_variables = list(workflow.outputs.keys())

        # Validate input variables
        for input_var in self.external_input_variables:
            if input_var not in workflow.inputs:
                raise ValueError(f"External input variable {input_var} not found in model step inputs")

        # Validate output variables
        for out_var in required_outputs:
            if out_var not in workflow.outputs:
                raise ValueError(f"Output variable {out_var} not found in the workflow outputs")

    def run(self, leadin: str, part: str) -> dict:
        """
        Process a bonus part with the given leadin.

        Args:
            leadin: The leadin text for the bonus question
            part: The specific part text to answer

        Returns:
            Dict with answer, confidence, and explanation
        """
        response, content, response_time = _get_model_step_response(
            self.model_step,
            {
                "leadin": leadin,
                "part": part,
            },
        )
        return {
            "answer": response["answer"],
            "confidence": response["confidence"],
            "explanation": response["explanation"],
            "full_response": content,
            "response_time": response_time,
        }
# Example usage: manual smoke test exercising both agents against a live
# dataset and LLM backend. Requires network access and API credentials.
if __name__ == "__main__":
    # Load the Quizbowl dataset
    from datasets import load_dataset

    from workflows.factory import create_quizbowl_bonus_step_initial_setup, create_quizbowl_simple_step_initial_setup

    ds_name = "umdclip/leaderboard_co_set"
    ds = load_dataset(ds_name, split="train")

    # Create the agents: each factory builds a one-step workflow, then the
    # model/provider are overridden here to use OpenAI's gpt-4.
    tossup_step = create_quizbowl_simple_step_initial_setup()
    tossup_step.model = "gpt-4"
    tossup_step.provider = "openai"
    tossup_agent = SimpleTossupAgent(workflow=tossup_step, buzz_threshold=0.9)

    bonus_step = create_quizbowl_bonus_step_initial_setup()
    bonus_step.model = "gpt-4"
    bonus_step.provider = "openai"
    bonus_agent = SimpleBonusAgent(workflow=bonus_step)

    # Example for tossup mode
    print("\n=== TOSSUP MODE EXAMPLE ===")
    # NOTE(review): assumes ds[30] is a tossup-formatted row with
    # "question_runs" and "gold_label" fields — confirm dataset schema.
    sample_question = ds[30]
    print(sample_question["question_runs"][-1])
    print(sample_question["gold_label"])
    print()
    question_runs = sample_question["question_runs"]
    # run() is a generator; with early_stop=True it stops after the first buzz.
    results = tossup_agent.run(question_runs, early_stop=True)
    for result in results:
        print(result["full_response"])
        print(f"Guess at position {result['position']}: {result['answer']}")
        print(f"Confidence: {result['confidence']}")
        if result["buzz"]:
            print("Buzzed!\n")

    # Example for bonus mode
    print("\n=== BONUS MODE EXAMPLE ===")
    sample_bonus = ds[31]  # Assuming this is a bonus question
    leadin = sample_bonus["leadin"]
    parts = sample_bonus["parts"]
    print(f"Leadin: {leadin}")
    # Each bonus part is answered independently, reusing the shared leadin.
    for i, part in enumerate(parts):
        print(f"\nPart {i + 1}: {part['part']}")
        result = bonus_agent.run(leadin, part["part"])
        print(f"Answer: {result['answer']}")
        print(f"Confidence: {result['confidence']}")
        print(f"Explanation: {result['explanation']}")
        print(f"Response time: {result['response_time']:.2f}s")