"""Supervisor node implementation for the agent supervisor system."""

import re
from typing import Dict, List, Literal, Optional, Union, Type, cast

from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command

from react_agent.configuration import Configuration
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router
from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
from react_agent import prompts


SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]


def supervisor_node(state: State) -> Command[SupervisorDestinations]:
    """Supervising LLM that decides which specialized agent should act next.

    Args:
        state: The current state with messages

    Returns:
        Command with routing information
    """
    configuration = Configuration.from_context()

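    # Track how many supervisor turns have been taken so far.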
    steps_taken = state.get("steps_taken", 0)
    steps_taken += 1
    state_updates = {"steps_taken": steps_taken}

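    # Safety valve: when we are within a few steps of the recursion limit,
    # stop routing to workers and return the best answer we can salvage.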
    if steps_taken >= configuration.recursion_limit - 5:
        context = state.get("context", {})
        answer = extract_best_answer_from_context(context)

        return Command(
            goto="final_answer",
            update={
                "messages": [
                    HumanMessage(
                        content=f"Maximum steps ({steps_taken}) reached. Extracting best answer from available information.",
                        name="supervisor",
                    )
                ],
                "draft_answer": f"FINAL ANSWER: {answer}",
                "retry_exhausted": True,
                "steps_taken": steps_taken,
            },
        )

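    # Likewise, give up after a bounded number of plan/critic retries.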
    retry_count = state.get("retry_count", 0)
    max_retries = 2

    if retry_count > max_retries:
        context = state.get("context", {})
        answer = extract_best_answer_from_context(context)

        return Command(
            goto="final_answer",
            update={
                "messages": [
                    HumanMessage(
                        content=f"Maximum retries ({max_retries}) reached. Extracting best answer from available information.",
                        name="supervisor",
                    )
                ],
                "draft_answer": f"FINAL ANSWER: {answer}",
                "retry_exhausted": True,
                "steps_taken": steps_taken,
            },
        )

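    # No plan yet: route to the planner before doing anything else.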
    if not state.get("plan"):
        return Command(
            goto="planner",
            update={
                **state_updates
            },
        )

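    # Reject degenerate plans with no steps and ask the planner to try again.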
    plan = state.get("plan")
    if not plan.get("steps"):
        return Command(
            goto="planner",
            update={
                "messages": [
                    HumanMessage(
                        content="Previous plan had 0 steps. Please create a plan with at least 1 step to solve the user's question.",
                        name="supervisor",
                    )
                ],
                "plan": None,
                **state_updates,
            },
        )

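    # If the critic has already reviewed a draft answer, act on its verdict.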
    critic_verdict = state.get("critic_verdict")
    if critic_verdict:
        if critic_verdict.get("verdict") == VERDICTS[0]:
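            # Approved: hand the draft off to the final answer node.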
            return Command(
                goto="final_answer",
                update={
                    "messages": [
                        HumanMessage(
                            content="Answer approved by critic. Generating final response.",
                            name="supervisor",
                        )
                    ]
                },
            )
        elif critic_verdict.get("verdict") == VERDICTS[1]:
            current_retry_count = state.get("retry_count", 0)

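            # Retries exhausted: stop iterating and use the best answer we have.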
            if current_retry_count >= max_retries:
                context = state.get("context", {})
                answer = extract_best_answer_from_context(context)

                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            HumanMessage(
                                content=f"Maximum retries ({max_retries}) reached. Proceeding with best available answer.",
                                name="supervisor",
                            )
                        ],
                        "draft_answer": f"FINAL ANSWER: {answer}",
                        "retry_exhausted": True,
                    },
                )

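            # Keep the accumulated context and worker results so the retry
            # does not lose information gathered so far.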
            context = state.get("context", {})
            worker_results = state.get("worker_results", {})

            reason = critic_verdict.get("reason", "")
            if not reason or reason.strip() in ("", '"'):
                reason = "Answer did not meet format requirements"

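            # Heuristically detect rejections that are about formatting rather
            # than about missing information.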
            format_issues = [
                "format", "concise", "explanation", "not formatted",
                "instead of just", "contains explanations", "final answer",
            ]
            is_format_issue = any(issue in reason.lower() for issue in format_issues)

            has_sufficient_info = has_sufficient_information(state)

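            # Formatting-only rejection with enough gathered information:
            # skip replanning and go straight to the final answer node.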
            if is_format_issue and has_sufficient_info and current_retry_count >= 0:
                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            HumanMessage(
                                content="We have sufficient information but formatting issues. Generating properly formatted answer.",
                                name="supervisor",
                            )
                        ],
                        "retry_count": current_retry_count + 1,
                    },
                )

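            # Otherwise discard the plan and draft and replan from scratch,
            # carrying the gathered context and worker results forward.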
            next_retry_count = current_retry_count + 1

            return Command(
                goto="planner",
                update={
                    "plan": None,
                    "current_step_index": None,
                    "draft_answer": None,
                    "critic_verdict": None,
                    "context": context,
                    "worker_results": worker_results,
                    "retry_count": next_retry_count,
                    "messages": [
                        HumanMessage(
                            content=f"Retrying with new plan (retry #{next_retry_count}). Reason: {reason}",
                            name="supervisor",
                        )
                    ],
                },
            )

    plan = state["plan"]
    current_step_index = state.get("current_step_index") or 0

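    # All plan steps completed: combine each worker's output into a draft
    # answer and hand it to the critic for review.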
    if current_step_index >= len(plan["steps"]):
        context = state.get("context", {})

        worker_results = []
        for worker in WORKERS:
            if worker in context:
                worker_results.append(f"**{worker.title()}**: {context[worker]}")

        draft_content = "\n\n".join(worker_results)

        return Command(
            goto="critic",
            update={
                "draft_answer": draft_content,
                "messages": [
                    HumanMessage(
                        content="All steps completed. Evaluating the answer.",
                        name="supervisor",
                    )
                ],
            },
        )

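    # Otherwise dispatch the current plan step to its assigned worker.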
    current_step = plan["steps"][current_step_index]
    worker = current_step["worker"]
    instruction = current_step["instruction"]

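    # Collect prior output from other workers that is relevant to this worker.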
    context_info = ""
    if state.get("context"):
        relevant_context = {}

        if worker == "coder" and "researcher" in state["context"]:
            relevant_context["researcher"] = state["context"]["researcher"]

        if worker == "researcher" and "coder" in state["context"]:
            coder_content = state["context"]["coder"]
            if len(coder_content) < 100:
                relevant_context["coder"] = coder_content

        context_items = []
        for key, value in relevant_context.items():
            if len(value) > 200:
                summary = value[:200]
                if '.' in summary:
                    summary = summary.split('.')[0] + '.'
                context_items.append(f"Previous {key} found: {summary}...")
            else:
                context_items.append(f"Previous {key} found: {value}")

        if context_items:
            context_info = "\n\nRelevant context: " + "\n".join(context_items)

    enhanced_instruction = f"{instruction}{context_info}"

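    # Add worker-specific guidance about the expected output.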
    if worker == "coder":
        enhanced_instruction += "\nProvide both your calculation method AND the final result value."
    elif worker == "researcher":
        enhanced_instruction += "\nFocus on gathering factual information related to the task."

    messages_update = [
        HumanMessage(
            content=f"Step {current_step_index + 1}: {enhanced_instruction}",
            name="supervisor",
        )
    ]

    worker_destination = cast(SupervisorDestinations, worker)

    return Command(
        goto=worker_destination,
        update={
            "messages": messages_update,
            "next": worker,
            **state_updates,
        },
    )


def extract_best_answer_from_context(context):
    """Extract the best available answer from context.

    This is a generic function to extract answers from any type of question
    context. It progressively tries different strategies to find a suitable
    answer.

    Args:
        context: The state context containing worker outputs

    Returns:
        Best answer found or "unknown" if nothing suitable is found
    """
    answer = "unknown"

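    # First preference: an explicit "FINAL ANSWER:" line in the coder output.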
    if "coder" in context:
        coder_content = context["coder"]

        answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", coder_content, re.IGNORECASE)
        if answer_match:
            return answer_match.group(1).strip()

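    # Next: mine the researcher output for bulleted or bolded items.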
    if "researcher" in context:
        researcher_content = context["researcher"]

        list_items = re.findall(r"[-•*]\s+([^:\n]+)", researcher_content)
        if list_items:
            answer = ",".join(item.strip() for item in list_items)
            return answer

        bold_items = re.findall(r"\*\*([^*]+)\*\*", researcher_content)
        if bold_items:
            processed_items = []
            for item in bold_items:
                clean_item = re.sub(r'(^|\s)(a|an|the|is|are|was|were|be|been)(\s|$)', ' ', item)
                clean_item = clean_item.strip()
                if clean_item and len(clean_item) < 30:
                    processed_items.append(clean_item)

            if processed_items:
                answer = ",".join(processed_items)
                return answer

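    # Last resort: the first number found anywhere in the worker outputs.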
    combined_content = ""
    for worker_type, content in context.items():
        combined_content += " " + content

    numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', combined_content)
    if numbers:
        answer = numbers[0]

    return answer


def has_sufficient_information(state):
    """Determine if we have enough information to generate a final answer.

    Args:
        state: The current conversation state

    Returns:
        Boolean indicating if we have sufficient information
    """
    context = state.get("context", {})

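    # Both the researcher and the coder have contributed output.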
    if "researcher" in context and "coder" in context:
        return True

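    # A reasonably long researcher answer on its own is treated as sufficient.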
    if "researcher" in context and len(context["researcher"]) > 150:
        return True

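    # Any output containing list markers or "key: value" style text also counts.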
    for worker, content in context.items():
        if content and (
            "- " in content or
            "•" in content or
            "*" in content or
            ":" in content
        ):
            return True

    return False