GaiaAgent / supervisor_node.py
sims2k's picture
Upload 8 files
ac6a4ef verified
"""Supervisor node implementation for the agent supervisor system."""
from typing import Dict, List, Literal, Optional, Union, Type, cast
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command
from react_agent.configuration import Configuration
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router
from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
from react_agent import prompts
# Static-typing alias: every node name the supervisor's Command may target
SupervisorDestinations = Literal[
    "planner",
    "critic",
    "researcher",
    "coder",
    "final_answer",
    "__end__",
]
def supervisor_node(state: State) -> Command[SupervisorDestinations]:
    """Supervising router that decides which specialized agent should act next.

    Routing rules, evaluated in order (first match wins):

    1. Step budget nearly used up (``recursion_limit - 5``): go to
       ``final_answer`` with the best answer extractable from context.
    2. ``retry_count`` already above the maximum: same forced ``final_answer``.
    3. No plan, or a plan without steps: go to ``planner``.
    4. A critic verdict is present:
       * ``VERDICTS[0]`` (CORRECT): go to ``final_answer``.
       * ``VERDICTS[1]`` (RETRY): forced ``final_answer`` when retries are
         used up; ``final_answer`` when the rejection looks like a pure
         formatting problem and enough information was gathered; otherwise
         back to ``planner`` with the plan cleared and context kept.
    5. Every plan step has run: go to ``critic`` with a draft compiled from
       the worker outputs stored in context.
    6. Otherwise: dispatch the current plan step to its worker.

    Every returned update carries the incremented ``steps_taken`` so the
    loop guard in rule 1 sees all supervisor visits.

    Args:
        state: The current graph state (read with mapping-style access).

    Returns:
        Command naming the next node and the state update to apply.
    """
    # Get configuration to read the recursion limit
    configuration = Configuration.from_context()

    # Track steps to prevent infinite loops; `or 0` tolerates an explicit None
    steps_taken = (state.get("steps_taken") or 0) + 1
    state_updates = {"steps_taken": steps_taken}

    # Stop early, leaving a buffer of 5 steps before the recursion limit
    if steps_taken >= configuration.recursion_limit - 5:
        return _force_final_answer(
            state,
            f"Maximum steps ({steps_taken}) reached. Extracting best answer from available information.",
            steps_taken,
        )

    # Safety net: terminate if the retry counter somehow ran past the maximum
    retry_count = state.get("retry_count") or 0
    max_retries = 2  # Maximum number of allowed retries
    if retry_count > max_retries:
        return _force_final_answer(
            state,
            f"Maximum retries ({max_retries}) reached. Extracting best answer from available information.",
            steps_taken,
        )

    # Without a plan there is nothing to route; ask the planner for one
    plan = state.get("plan")
    if not plan:
        return Command(goto="planner", update={**state_updates})

    # A plan with no steps is unusable; send it back with explicit instructions
    if not plan.get("steps"):
        return Command(
            goto="planner",
            update={
                "messages": [
                    _supervisor_message(
                        "Previous plan had 0 steps. Please create a plan with at least 1 step to solve the user's question."
                    )
                ],
                "plan": None,
                **state_updates,
            },
        )

    critic_verdict = state.get("critic_verdict")
    if critic_verdict:
        verdict = critic_verdict.get("verdict")

        if verdict == VERDICTS[0]:  # CORRECT
            # Approved: final_answer produces the polished response before ending
            return Command(
                goto="final_answer",
                update={
                    "messages": [
                        _supervisor_message(
                            "Answer approved by critic. Generating final response."
                        )
                    ],
                    **state_updates,
                },
            )

        if verdict == VERDICTS[1]:  # RETRY
            # Compare the count BEFORE incrementing it
            if retry_count >= max_retries:
                return _force_final_answer(
                    state,
                    f"Maximum retries ({max_retries}) reached. Proceeding with best available answer.",
                    steps_taken,
                )

            # Get the critic's reason for rejection; a lone quote counts as empty
            reason = critic_verdict.get("reason", "")
            if not reason or reason.strip() == "\"":
                reason = "Answer did not meet format requirements"

            # Keywords are lowercase because they are matched against reason.lower()
            format_issues = (
                "format", "concise", "explanation", "not formatted",
                "instead of just", "contains explanations", "final answer",
            )
            lowered_reason = reason.lower()
            is_format_issue = any(issue in lowered_reason for issue in format_issues)

            # Enough information but wrong format: skip replanning entirely
            if is_format_issue and has_sufficient_information(state):
                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            _supervisor_message(
                                "We have sufficient information but formatting issues. Generating properly formatted answer."
                            )
                        ],
                        "retry_count": retry_count + 1,  # Still counts as a retry
                        **state_updates,
                    },
                )

            # Replan: reset the plan but KEEP context gathered so far
            next_retry_count = retry_count + 1
            return Command(
                goto="planner",
                update={
                    "plan": None,
                    "current_step_index": None,
                    "draft_answer": None,
                    "critic_verdict": None,
                    "context": state.get("context", {}),
                    "worker_results": state.get("worker_results", {}),
                    "retry_count": next_retry_count,
                    "messages": [
                        _supervisor_message(
                            f"Retrying with new plan (retry #{next_retry_count}). Reason: {reason}"
                        )
                    ],
                    **state_updates,
                },
            )

    # The retry branch above stores None as the index, so treat None as step 0
    current_step_index = state.get("current_step_index") or 0
    steps = plan["steps"]

    # All steps done: compile worker outputs into a draft and ask the critic
    if current_step_index >= len(steps):
        context = state.get("context") or {}
        worker_outputs = [
            f"**{worker.title()}**: {context[worker]}"
            for worker in WORKERS
            if worker in context
        ]
        return Command(
            goto="critic",
            update={
                "draft_answer": "\n\n".join(worker_outputs),
                "messages": [
                    _supervisor_message("All steps completed. Evaluating the answer.")
                ],
                **state_updates,
            },
        )

    # Dispatch the current step to its worker
    current_step = steps[current_step_index]
    worker = current_step["worker"]
    instruction = current_step["instruction"]

    # Enhance the instruction with the context relevant to this worker
    context_info = _relevant_context_note(worker, state.get("context"))
    enhanced_instruction = f"{instruction}{context_info}"

    # Add guidance based on worker type
    if worker == "coder":
        enhanced_instruction += "\nProvide both your calculation method AND the final result value."
    elif worker == "researcher":
        enhanced_instruction += "\nFocus on gathering factual information related to the task."

    # Cast worker to the destination literal to satisfy type checking
    worker_destination = cast(SupervisorDestinations, worker)
    return Command(
        goto=worker_destination,
        update={
            "messages": [
                _supervisor_message(
                    f"Step {current_step_index + 1}: {enhanced_instruction}"
                )
            ],
            "next": worker,  # For backward compatibility
            **state_updates,
        },
    )


def _supervisor_message(text):
    """Wrap ``text`` in a HumanMessage attributed to the supervisor."""
    return HumanMessage(content=text, name="supervisor")


def _force_final_answer(state, notice, steps_taken):
    """Build the Command that ends the run with the best answer found in context.

    Args:
        state: The current graph state; only its ``context`` entry is read.
        notice: Already-formatted supervisor message explaining the cut-off.
        steps_taken: Incremented step counter to persist.

    Returns:
        Command routing to ``final_answer`` with ``retry_exhausted`` set.
    """
    answer = extract_best_answer_from_context(state.get("context") or {})
    return Command(
        goto="final_answer",
        update={
            "messages": [_supervisor_message(notice)],
            "draft_answer": f"FINAL ANSWER: {answer}",
            "retry_exhausted": True,  # Flag to indicate we've exhausted retries
            "steps_taken": steps_taken,
        },
    )


def _relevant_context_note(worker, context):
    """Return the "Relevant context" suffix for a worker's instruction, or "".

    The coder receives the researcher's output. The researcher receives the
    coder's output only when it is shorter than 100 characters (short results
    are likely just numbers, not code). Values longer than 200 characters are
    cut to their first sentence, or to 200 characters when no period is found.

    Args:
        worker: Name of the worker the current step is addressed to.
        context: Mapping of worker name to earlier output; may be None/empty.
    """
    if not context:
        return ""

    relevant = {}
    if worker == "coder" and "researcher" in context:
        relevant["researcher"] = context["researcher"]
    if worker == "researcher" and "coder" in context:
        coder_content = context["coder"]
        if len(coder_content) < 100:
            relevant["coder"] = coder_content

    context_items = []
    for key, value in relevant.items():
        if len(value) > 200:
            summary = value[:200]
            if "." in summary:
                summary = summary.split(".")[0] + "."
            context_items.append(f"Previous {key} found: {summary}...")
        else:
            context_items.append(f"Previous {key} found: {value}")

    if not context_items:
        return ""
    return "\n\nRelevant context: " + "\n".join(context_items)
def extract_best_answer_from_context(context):
    """Extract the best available answer from worker outputs.

    Strategies are tried in order and the first one that yields text wins:

    1. A non-empty ``FINAL ANSWER: X`` line in the coder output
       (case-insensitive).
    2. Bulleted lines (``-``, ``•`` or ``*`` at the start of a line) in the
       researcher output, each cut at its first colon, joined with commas.
    3. ``**bold**`` spans in the researcher output with filler words removed,
       keeping only spans shorter than 30 characters, joined with commas.
    4. The first number found in any worker output.

    Args:
        context: Mapping of worker name to that worker's text output. ``None``
            or an empty mapping is accepted; non-string values are ignored.

    Returns:
        Best answer found, or "unknown" if nothing suitable is found.
    """
    # Local import: the module's top-level imports do not include `re`
    import re

    if not context:
        return "unknown"

    def text_of(worker):
        value = context.get(worker)
        return value if isinstance(value, str) else ""

    # 1. A properly formatted answer from the coder takes priority
    final_match = re.search(
        r"FINAL ANSWER:\s*(.*?)(?:\n|$)", text_of("coder"), re.IGNORECASE
    )
    if final_match:
        final_answer = final_match.group(1).strip()
        if final_answer:  # an empty "FINAL ANSWER:" is not an answer
            return final_answer

    researcher_text = text_of("researcher")
    if researcher_text:
        # 2. Bullets must start a line; otherwise a mid-line hyphen or the
        #    closing "**" of bold text would be mistaken for a bullet marker.
        bullet_pattern = r"^[ \t]*[-•*][ \t]+([^:\n]+)"
        bullets = [
            item.strip()
            for item in re.findall(bullet_pattern, researcher_text, re.MULTILINE)
        ]
        bullets = [item for item in bullets if item]
        if bullets:
            return ",".join(bullets)

        # 3. Bold spans. Lookarounds leave the surrounding whitespace
        #    unconsumed so consecutive filler words are all removed.
        filler_pattern = re.compile(
            r"(?<!\S)(?:a|an|the|is|are|was|were|be|been)(?!\S)"
        )
        processed_items = []
        for item in re.findall(r"\*\*([^*]+)\*\*", researcher_text):
            clean_item = " ".join(filler_pattern.sub(" ", item).split())
            if clean_item and len(clean_item) < 30:  # only reasonably short items
                processed_items.append(clean_item)
        if processed_items:
            return ",".join(processed_items)

    # 4. Fall back to the first number mentioned by any worker
    combined_content = " ".join(
        content for content in context.values() if isinstance(content, str)
    )
    number_match = re.search(r"\b(\d+(?:\.\d+)?)\b", combined_content)
    if number_match:
        return number_match.group(1)

    return "unknown"
def has_sufficient_information(state):
    """Heuristically decide whether enough was gathered to write a final answer.

    Returns True when any of the following holds for ``state["context"]``:

    * both a "researcher" and a "coder" entry are present;
    * the "researcher" entry is a string longer than 150 characters;
    * any entry is a string containing a bullet marker ("- " or "•"),
      an asterisk (emphasis or bullet) or a colon.

    Args:
        state: The current conversation state; only ``context`` is read.
            A missing or ``None`` context counts as empty, and non-string
            entries are ignored by the text checks.

    Returns:
        Boolean indicating if we have sufficient information.
    """
    # `or {}` also covers a context key that is present but set to None
    context = state.get("context") or {}

    # Outputs from both workers: likely enough info
    if "researcher" in context and "coder" in context:
        return True

    # A substantial researcher output might be enough on its own
    researcher_output = context.get("researcher")
    if isinstance(researcher_output, str) and len(researcher_output) > 150:
        return True

    # Any worker output that looks like a list or formatted data
    structure_markers = ("- ", "•", "*", ":")
    for content in context.values():
        if not isinstance(content, str):
            continue
        if any(marker in content for marker in structure_markers):
            return True

    return False