GaiaAgent / graph.py
sims2k's picture
Upload 8 files
ac6a4ef verified
"""Define an Agent Supervisor graph with specialized worker agents.
The supervisor routes tasks to specialized agents based on the query type.
"""
from typing import Dict, List, Literal, Optional, Union, Type, cast
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
# Import adjusted for compatibility
from langgraph.prebuilt import create_react_agent # Try original import path first
from langgraph.types import Command
from react_agent.configuration import Configuration
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router, Plan, PlanStep, CriticVerdict
from react_agent.tools import TOOLS, tavily_tool, python_repl_tool
from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
from react_agent import prompts
from react_agent.supervisor_node import supervisor_node
# Compile-time type definitions
SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]
WorkerDestination = Literal["supervisor"]
# Helper function to check if a message is from a user
def is_user_message(message):
"""Check if a message is from a user regardless of message format."""
if isinstance(message, dict):
return message.get("role") == "user"
elif isinstance(message, HumanMessage):
return True
return False
# Helper function to get message content
def get_message_content(message):
"""Extract content from a message regardless of format."""
if isinstance(message, dict):
return message.get("content", "")
elif hasattr(message, "content"):
return message.content
return ""
# --- Planner node ---------------------------------------------------------
def planner_node(state: State) -> Command[WorkerDestination]:
"""Planning LLM that creates a step-by-step execution plan.
Args:
state: The current state with messages
Returns:
Command to update the state with a plan
"""
configuration = Configuration.from_context()
# Use the specific planner model
planner_llm = load_chat_model(configuration.planner_model)
# Track steps
steps_taken = state.get("steps_taken", 0)
steps_taken += 1
# Get the original user question (the latest user message)
user_messages = [m for m in state["messages"] if is_user_message(m)]
original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"
# Create a chat prompt template with proper formatting
planner_prompt_template = ChatPromptTemplate.from_messages([
("system", prompts.PLANNER_PROMPT),
("user", "{question}")
])
# Format the prompt with the necessary variables
formatted_messages = planner_prompt_template.format_messages(
question=original_question,
system_time=format_system_prompt("{system_time}"),
workers=", ".join(WORKERS),
worker_options=", ".join([f'"{w}"' for w in WORKERS]),
example_worker_1=WORKERS[0] if WORKERS else "researcher",
example_worker_2=WORKERS[1] if len(WORKERS) > 1 else "coder"
)
# Get structured output from the planner model
plan = planner_llm.with_structured_output(Plan).invoke(formatted_messages)
# Return with updated state
return Command(
goto="supervisor",
update={
"plan": plan,
"current_step_index": 0,
# Add a message to show the plan was created
"messages": [
HumanMessage(
content=f"Created plan with {len(plan['steps'])} steps",
name="planner"
)
],
"steps_taken": steps_taken
}
)
# --- Final Answer node -----------------------------------------------------
def final_answer_node(state: State) -> Command[Literal["__end__"]]:
"""Generate a final answer based on gathered information.
Args:
state: The current state with messages and context
Returns:
Command with final answer
"""
configuration = Configuration.from_context()
# Track steps
steps_taken = state.get("steps_taken", 0)
steps_taken += 1
# Check if we've exhausted retries and already have a draft answer
retry_exhausted = state.get("retry_exhausted", False)
draft_answer = state.get("draft_answer")
# Variable to store the final answer
gaia_answer = ""
if retry_exhausted and draft_answer and draft_answer.startswith("FINAL ANSWER:"):
# If supervisor already provided a properly formatted answer after exhausting retries,
# use it directly without calling the model again
import re
final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", draft_answer, re.IGNORECASE)
if final_answer_match:
gaia_answer = final_answer_match.group(1).strip()
else:
gaia_answer = "unknown"
else:
# Use the specific final answer model
final_llm = load_chat_model(configuration.final_answer_model)
# Get the original user question (the latest user message)
user_messages = [m for m in state["messages"] if is_user_message(m)]
original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"
# Check if we already have a draft answer from supervisor
if draft_answer and draft_answer.startswith("FINAL ANSWER:"):
# If supervisor already provided a properly formatted answer, use it directly
raw_answer = draft_answer
else:
# Get the context and worker results
context = state.get("context", {})
worker_results = state.get("worker_results", {})
# Compose a prompt for the final answer using the GAIA-specific format
final_prompt = ChatPromptTemplate.from_messages([
("system", prompts.FINAL_ANSWER_PROMPT),
("user", prompts.FINAL_ANSWER_USER_PROMPT)
])
# Format the context information more effectively
context_list = []
# First include researcher context as it provides background
if "researcher" in context:
context_list.append(f"Research information: {context['researcher']}")
# Then include coder results which are typically calculations
if "coder" in context:
context_list.append(f"Calculation results: {context['coder']}")
# Add any other workers
for worker, content in context.items():
if worker not in ["researcher", "coder"]:
context_list.append(f"{worker.capitalize()}: {content}")
# Get the final answer
formatted_messages = final_prompt.format_messages(
question=original_question,
context="\n\n".join(context_list)
)
raw_answer = final_llm.invoke(formatted_messages).content
# Extract the answer in GAIA format: "FINAL ANSWER: [x]"
import re
gaia_answer = raw_answer
final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", raw_answer, re.IGNORECASE)
if final_answer_match:
gaia_answer = final_answer_match.group(1).strip()
# Ensure answer is properly formatted - if we don't have a valid answer
# but have sufficient context, try to extract directly
if configuration.allow_agent_to_extract_answers and (not gaia_answer or gaia_answer.lower() in ["unknown", "insufficient information"]):
context = state.get("context", {})
from react_agent.supervisor_node import extract_best_answer_from_context
extracted_answer = extract_best_answer_from_context(context)
if extracted_answer != "unknown":
gaia_answer = extracted_answer
# Set status to "final_answer_generated" to indicate we're done
return Command(
goto=END,
update={
"messages": [
AIMessage(
content=f"FINAL ANSWER: {gaia_answer}",
name="supervisor"
)
],
"next": "FINISH", # Update next to indicate we're done
"gaia_answer": gaia_answer, # Store answer in GAIA-compatible format
"submitted_answer": gaia_answer, # Store as submitted_answer for GAIA benchmark
"status": "final_answer_generated", # Add status to indicate we're complete
"steps_taken": steps_taken
}
)
# --- Critic node ----------------------------------------------------------
def critic_node(state: State) -> Command[Union[WorkerDestination, SupervisorDestinations]]:
"""Critic that evaluates if the answer fully satisfies the request.
Args:
state: The current state with messages and draft answer
Returns:
Command with evaluation verdict
"""
configuration = Configuration.from_context()
# Use the specific critic model
critic_llm = load_chat_model(configuration.critic_model)
# Track steps
steps_taken = state.get("steps_taken", 0)
steps_taken += 1
# Get the original user question (the latest user message)
user_messages = [m for m in state["messages"] if is_user_message(m)]
original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"
# Get the draft answer
draft_answer = state.get("draft_answer", "No answer provided.")
# Create a chat prompt template with proper formatting
critic_prompt_template = ChatPromptTemplate.from_messages([
("system", prompts.CRITIC_PROMPT),
("user", prompts.CRITIC_USER_PROMPT)
])
# Format the prompt with the necessary variables
formatted_messages = critic_prompt_template.format_messages(
question=original_question,
answer=draft_answer,
system_time=format_system_prompt("{system_time}"),
correct_verdict=VERDICTS[0] if VERDICTS else "CORRECT",
retry_verdict=VERDICTS[1] if len(VERDICTS) > 1 else "RETRY"
)
# Get structured output from the critic model
verdict = critic_llm.with_structured_output(CriticVerdict).invoke(formatted_messages)
# Add a message about the verdict
if verdict["verdict"] == VERDICTS[0]: # CORRECT
verdict_message = "Answer is complete, accurate, and properly formatted for GAIA."
goto = "final_answer" # Go to final answer node if correct
else:
verdict_message = f"Answer needs improvement. Reason: {verdict.get('reason', 'Unknown')}"
goto = "supervisor"
# Return with updated state
return Command(
goto=goto,
update={
"critic_verdict": verdict,
"messages": [
HumanMessage(
content=verdict_message,
name="critic"
)
],
"steps_taken": steps_taken
}
)
# --- Worker agent factory -------------------------------------------------
def create_worker_node(worker_type: str):
"""Factory function to create a worker node of the specified type.
Args:
worker_type: The type of worker to create (must be in WORKERS)
Returns:
A function that processes requests for the specified worker type
"""
if worker_type not in WORKERS:
raise ValueError(f"Unknown worker type: {worker_type}")
configuration = Configuration.from_context()
# Select the appropriate model for each worker type
if worker_type == "researcher":
llm = load_chat_model(configuration.researcher_model)
worker_prompt = prompts.RESEARCHER_PROMPT
worker_tools = [tavily_tool]
elif worker_type == "coder":
llm = load_chat_model(configuration.coder_model)
worker_prompt = prompts.CODER_PROMPT
worker_tools = [python_repl_tool]
else:
# Default case
llm = load_chat_model(configuration.model)
worker_prompt = getattr(prompts, f"{worker_type.upper()}_PROMPT", prompts.SYSTEM_PROMPT)
worker_tools = TOOLS
# Create the agent
worker_agent = create_react_agent(
llm,
tools=worker_tools,
prompt=format_system_prompt(worker_prompt)
)
# Define node function
def worker_node(state: State) -> Command[WorkerDestination]:
"""Process requests using the specified worker.
Args:
state: The current conversation state
Returns:
Command to return to supervisor with results
"""
# Track steps
steps_taken = state.get("steps_taken", 0)
steps_taken += 1
# Get the last message from the supervisor, which contains our task
task_message = None
if state.get("messages"):
for msg in reversed(state["messages"]):
if hasattr(msg, "name") and msg.name == "supervisor":
task_message = msg
break
if not task_message:
return Command(
goto="supervisor",
update={
"messages": [
HumanMessage(
content=f"Error: No task message found for {worker_type}",
name=worker_type
)
],
"steps_taken": steps_taken
}
)
# Create a new state with just the relevant messages for this worker
# This prevents confusion from unrelated parts of the conversation
agent_input = {
"messages": [
# Include the first user message for context
state["messages"][0] if state["messages"] else HumanMessage(content="Help me"),
# Include the task message
task_message
]
}
# Invoke the agent with the clean input
result = worker_agent.invoke(agent_input)
# Extract the result from the agent response
result_content = extract_worker_result(worker_type, result, state)
# Store the worker's result in shared context
context_update = state.get("context", {}).copy()
context_update[worker_type] = result_content
# Store in worker_results history
worker_results = state.get("worker_results", {}).copy()
if worker_type not in worker_results:
worker_results[worker_type] = []
worker_results[worker_type].append(result_content)
# Increment the step index after worker completes
current_step_index = state.get("current_step_index", 0)
return Command(
update={
"messages": [
HumanMessage(content=result_content, name=worker_type)
],
"current_step_index": current_step_index + 1,
"context": context_update,
"worker_results": worker_results,
"steps_taken": steps_taken
},
goto="supervisor",
)
return worker_node
def extract_worker_result(worker_type: str, result: dict, state: State) -> str:
"""Extract a clean, useful result from the worker's output.
This handles different response formats from different worker types.
Args:
worker_type: The type of worker (researcher or coder)
result: The raw result from the worker agent
state: The current state for context
Returns:
A cleaned string with the relevant result information
"""
# Handle empty results
if not result or "messages" not in result or not result["messages"]:
return f"No output from {worker_type}"
# Get the last message from the agent
last_message = result["messages"][-1]
# Default to extracting content directly
if hasattr(last_message, "content") and last_message.content:
result_content = last_message.content
else:
result_content = f"No content from {worker_type}"
# Special handling based on worker type
if worker_type == "coder":
# For coder outputs, extract the actual result values from code execution
if "```" in result_content:
# Try to extract stdout from code execution
import re
stdout_match = re.search(r"Stdout:\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
if stdout_match:
# Extract the actual execution output, not just the code
execution_result = stdout_match.group(1).strip()
if execution_result:
# Check if this is just a simple number result
if re.match(r"^\d+(\.\d+)?$", execution_result):
return execution_result
else:
return f"Code executed with result: {execution_result}"
# If we couldn't find stdout, try to extract output in a different way
# Look for "Result:" or similar indicators
result_match = re.search(r"(?:Result|Output|Answer):\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
if result_match:
return result_match.group(1).strip()
elif worker_type == "researcher":
# For researcher outputs, keep the full detailed response
# but ensure it's well-formatted
if len(result_content) > 800:
# If too long, try to extract key sections
# Look for summary or conclusion sections
import re
summary_match = re.search(r"(?:Summary|Conclusion|To summarize|In summary):(.*?)(?:\n\n|$)",
result_content, re.IGNORECASE | re.DOTALL)
if summary_match:
return summary_match.group(1).strip()
# If no special handling was triggered, return the content as is
return result_content
# --- Graph assembly -------------------------------------------------------
def create_agent_supervisor_graph() -> StateGraph:
"""Create the agent supervisor graph with all nodes and edges.
Returns:
Compiled StateGraph ready for execution
"""
# Initialize the graph with our State type
builder = StateGraph(State)
# Add control nodes
builder.add_node("planner", planner_node)
builder.add_node("supervisor", supervisor_node)
builder.add_node("critic", critic_node)
builder.add_node("final_answer", final_answer_node)
# Add worker nodes dynamically based on WORKERS list
for worker_type in WORKERS:
builder.add_node(worker_type, create_worker_node(worker_type))
# Define the workflow
builder.add_edge(START, "supervisor")
builder.add_edge("planner", "supervisor")
builder.add_edge("critic", "supervisor")
builder.add_edge("critic", "final_answer") # Add edge from critic to final_answer
builder.add_edge("final_answer", END) # Final answer node goes to END
builder.add_edge("supervisor", END) # Allow the supervisor to end the workflow
# Connect all workers to supervisor
for worker_type in WORKERS:
builder.add_edge(worker_type, "supervisor")
# Return the builder, not a compiled graph
# This allows the caller to compile with a checkpointer
return builder
# --- Graph instantiation (with flexible checkpointing) -----------------------------
def get_compiled_graph(checkpointer=None):
"""Get a compiled graph with optional checkpointer.
Args:
checkpointer: Optional checkpointer for persistence
Returns:
Compiled StateGraph ready for execution
"""
# Get configuration
configuration = Configuration.from_context()
builder = create_agent_supervisor_graph()
# Define termination condition function to prevent loops
def should_end(state):
"""Determine if the graph should terminate."""
# End if status is set to final_answer_generated
if state.get("status") == "final_answer_generated":
return True
# End if retry_exhausted flag is set and we've gone through final_answer
if state.get("retry_exhausted") and state.get("gaia_answer"):
return True
# End if we've hit maximum recursion limit defined by LangGraph
steps_taken = state.get("steps_taken", 0)
if steps_taken >= configuration.recursion_limit - 5: # Leave buffer
return True
return False
# Define step counter for tracking step count
def count_steps(state):
"""Count steps to prevent infinite loops."""
steps_taken = state.get("steps_taken", 0)
return {"steps_taken": steps_taken + 1}
# Compile the graph (don't use add_state_transform which isn't available)
if checkpointer:
graph = builder.compile(
checkpointer=checkpointer,
name="Structured Reasoning Loop"
)
else:
graph = builder.compile(
name="Structured Reasoning Loop"
)
# Configure the graph with recursion limit and max iterations
graph = graph.with_config({
"recursion_limit": configuration.recursion_limit,
"max_iterations": configuration.max_iterations
})
return graph
# Initialize a default non-checkpointed graph (for backward compatibility)
graph = get_compiled_graph()