|
"""Define an Agent Supervisor graph with specialized worker agents. |
|
|
|
The supervisor routes tasks to specialized agents based on the query type. |
|
""" |
|
|
|
import re
from typing import Literal
|
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langgraph.graph import StateGraph, START, END |
|
|
|
from langgraph.prebuilt import create_react_agent |
|
from langgraph.types import Command |
|
|
|
from react_agent.configuration import Configuration |
|
from react_agent.state import WORKERS, VERDICTS, CriticVerdict, Plan, State
|
from react_agent.tools import TOOLS, tavily_tool, python_repl_tool |
|
from react_agent.utils import format_system_prompt, load_chat_model
|
from react_agent import prompts |
|
from react_agent.supervisor_node import extract_best_answer_from_context, supervisor_node
|
|
|
|
|
|
|
# Destinations the supervisor may route to; workers always return to the supervisor.
SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]

WorkerDestination = Literal["supervisor"]
|
|
|
|
|
|
|
def is_user_message(message) -> bool:
|
"""Check if a message is from a user regardless of message format.""" |
|
if isinstance(message, dict): |
|
return message.get("role") == "user" |
|
elif isinstance(message, HumanMessage): |
|
return True |
|
return False |
|
|
|
|
|
|
|
def get_message_content(message) -> str:
|
"""Extract content from a message regardless of format.""" |
|
if isinstance(message, dict): |
|
return message.get("content", "") |
|
elif hasattr(message, "content"): |
|
return message.content |
|
return "" |
|
|
|
|
|
|
|
|
|
def planner_node(state: State) -> Command[WorkerDestination]: |
|
"""Planning LLM that creates a step-by-step execution plan. |
|
|
|
Args: |
|
state: The current state with messages |
|
|
|
Returns: |
|
Command to update the state with a plan |
|
""" |
|
configuration = Configuration.from_context() |
|
|
|
planner_llm = load_chat_model(configuration.planner_model) |
|
|
|
|
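    # Count every node execution so the graph can stop before the recursion limit.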
|
    steps_taken = state.get("steps_taken", 0) + 1
|
|
|
|
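    # The most recent user message is the question to plan for.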
|
user_messages = [m for m in state["messages"] if is_user_message(m)] |
|
original_question = get_message_content(user_messages[-1]) if user_messages else "Help me" |
|
|
|
|
|
planner_prompt_template = ChatPromptTemplate.from_messages([ |
|
("system", prompts.PLANNER_PROMPT), |
|
("user", "{question}") |
|
]) |
|
|
|
|
|
formatted_messages = planner_prompt_template.format_messages( |
|
question=original_question, |
|
system_time=format_system_prompt("{system_time}"), |
|
workers=", ".join(WORKERS), |
|
worker_options=", ".join([f'"{w}"' for w in WORKERS]), |
|
example_worker_1=WORKERS[0] if WORKERS else "researcher", |
|
example_worker_2=WORKERS[1] if len(WORKERS) > 1 else "coder" |
|
) |
|
|
|
|
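    # Request a structured Plan from the model; its "steps" drive the supervisor loop.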
|
plan = planner_llm.with_structured_output(Plan).invoke(formatted_messages) |
|
|
|
|
|
return Command( |
|
goto="supervisor", |
|
update={ |
|
"plan": plan, |
|
"current_step_index": 0, |
|
|
|
"messages": [ |
|
HumanMessage( |
|
content=f"Created plan with {len(plan['steps'])} steps", |
|
name="planner" |
|
) |
|
], |
|
"steps_taken": steps_taken |
|
} |
|
) |
|
|
|
|
|
|
|
|
|
def final_answer_node(state: State) -> Command[Literal["__end__"]]: |
|
"""Generate a final answer based on gathered information. |
|
|
|
Args: |
|
state: The current state with messages and context |
|
|
|
Returns: |
|
Command with final answer |
|
""" |
|
configuration = Configuration.from_context() |
|
|
|
|
|
    steps_taken = state.get("steps_taken", 0) + 1
|
|
|
|
|
retry_exhausted = state.get("retry_exhausted", False) |
|
draft_answer = state.get("draft_answer") |
|
|
|
|
|
gaia_answer = "" |
|
|
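    # If retries are exhausted and the draft is already in GAIA format, reuse it
    # rather than calling the model again.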
|
if retry_exhausted and draft_answer and draft_answer.startswith("FINAL ANSWER:"): |
|
|
|
|
|
|
final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", draft_answer, re.IGNORECASE) |
|
if final_answer_match: |
|
gaia_answer = final_answer_match.group(1).strip() |
|
else: |
|
gaia_answer = "unknown" |
|
else: |
|
|
|
final_llm = load_chat_model(configuration.final_answer_model) |
|
|
|
|
|
user_messages = [m for m in state["messages"] if is_user_message(m)] |
|
original_question = get_message_content(user_messages[-1]) if user_messages else "Help me" |
|
|
|
|
|
if draft_answer and draft_answer.startswith("FINAL ANSWER:"): |
|
|
|
raw_answer = draft_answer |
|
else: |
|
|
|
context = state.get("context", {}) |
|
worker_results = state.get("worker_results", {}) |
|
|
|
|
|
final_prompt = ChatPromptTemplate.from_messages([ |
|
("system", prompts.FINAL_ANSWER_PROMPT), |
|
("user", prompts.FINAL_ANSWER_USER_PROMPT) |
|
]) |
|
|
|
|
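            # Assemble worker findings into one context block, with research and
            # calculation results listed first.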
|
context_list = [] |
|
|
|
if "researcher" in context: |
|
context_list.append(f"Research information: {context['researcher']}") |
|
|
|
|
|
if "coder" in context: |
|
context_list.append(f"Calculation results: {context['coder']}") |
|
|
|
|
|
for worker, content in context.items(): |
|
if worker not in ["researcher", "coder"]: |
|
context_list.append(f"{worker.capitalize()}: {content}") |
|
|
|
|
|
formatted_messages = final_prompt.format_messages( |
|
question=original_question, |
|
context="\n\n".join(context_list) |
|
) |
|
|
|
raw_answer = final_llm.invoke(formatted_messages).content |
|
|
|
|
|
|
gaia_answer = raw_answer |
|
final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", raw_answer, re.IGNORECASE) |
|
if final_answer_match: |
|
gaia_answer = final_answer_match.group(1).strip() |
|
|
|
|
|
|
|
    # As a last resort, optionally mine the accumulated worker context for an answer.
    if configuration.allow_agent_to_extract_answers and (
        not gaia_answer
        or gaia_answer.lower() in ("unknown", "insufficient information")
    ):
        context = state.get("context", {})
        extracted_answer = extract_best_answer_from_context(context)
|
if extracted_answer != "unknown": |
|
gaia_answer = extracted_answer |
|
|
|
|
|
return Command( |
|
goto=END, |
|
update={ |
|
"messages": [ |
|
AIMessage( |
|
content=f"FINAL ANSWER: {gaia_answer}", |
|
name="supervisor" |
|
) |
|
], |
|
"next": "FINISH", |
|
"gaia_answer": gaia_answer, |
|
"submitted_answer": gaia_answer, |
|
"status": "final_answer_generated", |
|
"steps_taken": steps_taken |
|
} |
|
) |
|
|
|
|
|
|
|
|
|
def critic_node(state: State) -> Command[Literal["supervisor", "final_answer"]]:
|
"""Critic that evaluates if the answer fully satisfies the request. |
|
|
|
Args: |
|
state: The current state with messages and draft answer |
|
|
|
Returns: |
|
Command with evaluation verdict |
|
""" |
|
configuration = Configuration.from_context() |
|
|
|
critic_llm = load_chat_model(configuration.critic_model) |
|
|
|
|
|
    steps_taken = state.get("steps_taken", 0) + 1
|
|
|
|
|
user_messages = [m for m in state["messages"] if is_user_message(m)] |
|
original_question = get_message_content(user_messages[-1]) if user_messages else "Help me" |
|
|
|
|
|
draft_answer = state.get("draft_answer", "No answer provided.") |
|
|
|
|
|
critic_prompt_template = ChatPromptTemplate.from_messages([ |
|
("system", prompts.CRITIC_PROMPT), |
|
("user", prompts.CRITIC_USER_PROMPT) |
|
]) |
|
|
|
|
|
formatted_messages = critic_prompt_template.format_messages( |
|
question=original_question, |
|
answer=draft_answer, |
|
system_time=format_system_prompt("{system_time}"), |
|
correct_verdict=VERDICTS[0] if VERDICTS else "CORRECT", |
|
retry_verdict=VERDICTS[1] if len(VERDICTS) > 1 else "RETRY" |
|
) |
|
|
|
|
|
verdict = critic_llm.with_structured_output(CriticVerdict).invoke(formatted_messages) |
|
|
|
|
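    # VERDICTS[0] is the "correct" verdict; anything else sends the draft back
    # to the supervisor for another pass.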
|
if verdict["verdict"] == VERDICTS[0]: |
|
verdict_message = "Answer is complete, accurate, and properly formatted for GAIA." |
|
goto = "final_answer" |
|
else: |
|
verdict_message = f"Answer needs improvement. Reason: {verdict.get('reason', 'Unknown')}" |
|
goto = "supervisor" |
|
|
|
|
|
return Command( |
|
goto=goto, |
|
update={ |
|
"critic_verdict": verdict, |
|
"messages": [ |
|
HumanMessage( |
|
content=verdict_message, |
|
name="critic" |
|
) |
|
], |
|
"steps_taken": steps_taken |
|
} |
|
) |
|
|
|
|
|
|
|
|
|
def create_worker_node(worker_type: str): |
|
"""Factory function to create a worker node of the specified type. |
|
|
|
Args: |
|
worker_type: The type of worker to create (must be in WORKERS) |
|
|
|
Returns: |
|
A function that processes requests for the specified worker type |
|
""" |
|
if worker_type not in WORKERS: |
|
raise ValueError(f"Unknown worker type: {worker_type}") |
|
|
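    # Configuration is read once, at graph-construction time; each worker is
    # built with the model, prompt, and tools for its role.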
|
configuration = Configuration.from_context() |
|
|
|
|
|
if worker_type == "researcher": |
|
llm = load_chat_model(configuration.researcher_model) |
|
worker_prompt = prompts.RESEARCHER_PROMPT |
|
worker_tools = [tavily_tool] |
|
elif worker_type == "coder": |
|
llm = load_chat_model(configuration.coder_model) |
|
worker_prompt = prompts.CODER_PROMPT |
|
worker_tools = [python_repl_tool] |
|
else: |
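        # Any other worker type: default model, its own prompt if one exists
        # (else the generic system prompt), and the full toolset.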
|
|
|
llm = load_chat_model(configuration.model) |
|
worker_prompt = getattr(prompts, f"{worker_type.upper()}_PROMPT", prompts.SYSTEM_PROMPT) |
|
worker_tools = TOOLS |
|
|
|
|
|
worker_agent = create_react_agent( |
|
llm, |
|
tools=worker_tools, |
|
prompt=format_system_prompt(worker_prompt) |
|
) |
|
|
|
|
|
def worker_node(state: State) -> Command[WorkerDestination]: |
|
"""Process requests using the specified worker. |
|
|
|
Args: |
|
state: The current conversation state |
|
|
|
Returns: |
|
Command to return to supervisor with results |
|
""" |
|
|
|
        steps_taken = state.get("steps_taken", 0) + 1
|
|
|
|
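        # Find the most recent instruction from the supervisor; that is this
        # worker's task.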
|
task_message = None |
|
if state.get("messages"): |
|
for msg in reversed(state["messages"]): |
|
if hasattr(msg, "name") and msg.name == "supervisor": |
|
task_message = msg |
|
break |
|
|
|
if not task_message: |
|
return Command( |
|
goto="supervisor", |
|
update={ |
|
"messages": [ |
|
HumanMessage( |
|
content=f"Error: No task message found for {worker_type}", |
|
name=worker_type |
|
) |
|
], |
|
"steps_taken": steps_taken |
|
} |
|
) |
|
|
|
|
|
|
|
        agent_input = {
            "messages": [
                # Keep the original user question for context...
                state["messages"][0] if state["messages"] else HumanMessage(content="Help me"),
                # ...plus the supervisor's task instruction.
                task_message
            ]
        }
|
|
|
|
|
result = worker_agent.invoke(agent_input) |
|
|
|
|
|
result_content = extract_worker_result(worker_type, result, state) |
|
|
|
|
|
context_update = state.get("context", {}).copy() |
|
context_update[worker_type] = result_content |
|
|
|
|
|
worker_results = state.get("worker_results", {}).copy() |
|
if worker_type not in worker_results: |
|
worker_results[worker_type] = [] |
|
worker_results[worker_type].append(result_content) |
|
|
|
|
|
current_step_index = state.get("current_step_index", 0) |
|
|
|
return Command( |
|
update={ |
|
"messages": [ |
|
HumanMessage(content=result_content, name=worker_type) |
|
], |
|
"current_step_index": current_step_index + 1, |
|
"context": context_update, |
|
"worker_results": worker_results, |
|
"steps_taken": steps_taken |
|
}, |
|
goto="supervisor", |
|
) |
|
|
|
return worker_node |
|
|
|
|
|
def extract_worker_result(worker_type: str, result: dict, state: State) -> str: |
|
"""Extract a clean, useful result from the worker's output. |
|
|
|
This handles different response formats from different worker types. |
|
|
|
Args: |
|
worker_type: The type of worker (researcher or coder) |
|
result: The raw result from the worker agent |
|
state: The current state for context |
|
|
|
Returns: |
|
A cleaned string with the relevant result information |
|
""" |
|
|
|
if not result or "messages" not in result or not result["messages"]: |
|
return f"No output from {worker_type}" |
|
|
|
|
|
last_message = result["messages"][-1] |
|
|
|
|
|
if hasattr(last_message, "content") and last_message.content: |
|
result_content = last_message.content |
|
else: |
|
result_content = f"No content from {worker_type}" |
|
|
|
|
|
if worker_type == "coder": |
|
|
|
if "```" in result_content: |
|
|
|
|
stdout_match = re.search(r"Stdout:\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL) |
|
if stdout_match: |
|
|
|
execution_result = stdout_match.group(1).strip() |
|
if execution_result: |
|
|
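                    # Bare numeric results are returned verbatim (GAIA answers
                    # are often plain numbers).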
|
if re.match(r"^\d+(\.\d+)?$", execution_result): |
|
return execution_result |
|
else: |
|
return f"Code executed with result: {execution_result}" |
|
|
|
|
|
|
|
result_match = re.search(r"(?:Result|Output|Answer):\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL) |
|
if result_match: |
|
return result_match.group(1).strip() |
|
|
|
elif worker_type == "researcher": |
|
|
|
|
|
if len(result_content) > 800: |
|
|
|
|
|
|
summary_match = re.search(r"(?:Summary|Conclusion|To summarize|In summary):(.*?)(?:\n\n|$)", |
|
result_content, re.IGNORECASE | re.DOTALL) |
|
if summary_match: |
|
return summary_match.group(1).strip() |
|
|
|
|
|
return result_content |
|
|
|
|
|
|
|
|
|
def create_agent_supervisor_graph() -> StateGraph: |
|
"""Create the agent supervisor graph with all nodes and edges. |
|
|
|
Returns: |
|
Compiled StateGraph ready for execution |
|
""" |
|
|
|
builder = StateGraph(State) |
|
|
|
|
|
builder.add_node("planner", planner_node) |
|
builder.add_node("supervisor", supervisor_node) |
|
builder.add_node("critic", critic_node) |
|
builder.add_node("final_answer", final_answer_node) |
|
|
|
|
|
for worker_type in WORKERS: |
|
builder.add_node(worker_type, create_worker_node(worker_type)) |
|
|
|
|
|
builder.add_edge(START, "supervisor") |
|
builder.add_edge("planner", "supervisor") |
|
builder.add_edge("critic", "supervisor") |
|
builder.add_edge("critic", "final_answer") |
|
builder.add_edge("final_answer", END) |
|
builder.add_edge("supervisor", END) |
|
|
|
|
|
for worker_type in WORKERS: |
|
builder.add_edge(worker_type, "supervisor") |
|
|
|
|
|
|
|
return builder |
|
|
|
|
|
|
|
|
|
def get_compiled_graph(checkpointer=None): |
|
"""Get a compiled graph with optional checkpointer. |
|
|
|
Args: |
|
checkpointer: Optional checkpointer for persistence |
|
|
|
Returns: |
|
Compiled StateGraph ready for execution |
|
""" |
|
|
|
configuration = Configuration.from_context() |
|
|
|
builder = create_agent_supervisor_graph() |
|
|
|
|
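    # NOTE: should_end and count_steps below are loop-protection helpers that are
    # not currently wired into the graph; termination relies on nodes returning
    # Command(goto=END) plus the recursion_limit applied below.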
|
def should_end(state): |
|
"""Determine if the graph should terminate.""" |
|
|
|
if state.get("status") == "final_answer_generated": |
|
return True |
|
|
|
|
|
if state.get("retry_exhausted") and state.get("gaia_answer"): |
|
return True |
|
|
|
|
|
steps_taken = state.get("steps_taken", 0) |
|
if steps_taken >= configuration.recursion_limit - 5: |
|
return True |
|
|
|
return False |
|
|
|
|
|
def count_steps(state): |
|
"""Count steps to prevent infinite loops.""" |
|
steps_taken = state.get("steps_taken", 0) |
|
return {"steps_taken": steps_taken + 1} |
|
|
|
|
|
if checkpointer: |
|
graph = builder.compile( |
|
checkpointer=checkpointer, |
|
name="Structured Reasoning Loop" |
|
) |
|
else: |
|
graph = builder.compile( |
|
name="Structured Reasoning Loop" |
|
) |
|
|
|
|
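    # Apply run-time limits from the configuration to every invocation.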
|
graph = graph.with_config({ |
|
"recursion_limit": configuration.recursion_limit, |
|
"max_iterations": configuration.max_iterations |
|
}) |
|
|
|
return graph |
|
|
|
|
|
|
|
# Module-level default: a compiled graph with no checkpointer, built at import time.
graph = get_compiled_graph()
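
# Example usage (a minimal sketch; assumes this module is importable as
# react_agent.graph and that any API keys the tools require are set):
#
#     from react_agent.graph import graph
#     result = graph.invoke({"messages": [("user", "What is 2 + 2?")]})
#     print(result.get("gaia_answer"))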
|
|