Upload 8 files
- app.py +375 -0
- configuration.py +166 -0
- graph.py +574 -0
- prompts.py +180 -0
- state.py +72 -0
- supervisor_node.py +406 -0
- tools.py +58 -0
- utils.py +76 -0
app.py
ADDED
@@ -0,0 +1,375 @@
"""Web application for the Agent Supervisor with GAIA benchmark integration.

This module provides a Gradio web interface for interacting with the Agent Supervisor
and evaluating it against the GAIA benchmark.
"""

import os
import json
import uuid
import asyncio
import requests
import pandas as pd
import gradio as gr

from typing import Dict, List, Optional
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver

from react_agent.graph import create_agent_supervisor_graph, get_compiled_graph

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

class GaiaAgent:
    """Agent implementation for the GAIA benchmark using the LangGraph supervisor."""

    def __init__(self, model_name=None, checkpointer=None):
        """Initialize the GAIA agent with LangGraph architecture.

        Args:
            model_name: Optional model name to override the default
            checkpointer: Optional checkpointer for persistence
        """
        print("Initializing GaiaAgent...")

        # Import Configuration class
        from react_agent.configuration import Configuration

        # Get configuration
        config = Configuration.from_context()
        default_model = config.model

        # If no checkpointer provided, create a default one - using MemorySaver to avoid SQLite thread issues
        if checkpointer is None:
            checkpointer = MemorySaver()
            print("Using in-memory checkpointer to avoid thread safety issues")

        # Create and compile the graph
        self.graph = get_compiled_graph(checkpointer=checkpointer)

        # Configure the agent using values from Configuration
        self.config = {
            "configurable": {
                # Use configuration model or override if provided
                "model": model_name if model_name else default_model,
                # Specific models for each role from Configuration
                "researcher_model": config.researcher_model,
                "coder_model": config.coder_model,
                "planner_model": config.planner_model,
                "supervisor_model": config.supervisor_model,
                "critic_model": config.critic_model,
                "final_answer_model": config.final_answer_model,
                # Other settings from Configuration
                "max_search_results": config.max_search_results,
                "recursion_limit": config.recursion_limit,
                "max_iterations": config.max_iterations,
                "allow_agent_to_extract_answers": config.allow_agent_to_extract_answers
            }
        }

        print(f"GaiaAgent initialized successfully with model: {self.config['configurable']['model']}")

    def __call__(self, question: str) -> str:
        """Process a question and return an answer formatted for GAIA benchmark.

        Args:
            question: The GAIA benchmark question

        Returns:
            Answer formatted for GAIA benchmark evaluation
        """
        print(f"Agent received question: {question[:100]}...")

        # Create a thread_id for this interaction
        thread_id = str(uuid.uuid4())
        self.config["configurable"]["thread_id"] = thread_id

        # Import configuration
        from react_agent.configuration import Configuration
        config = Configuration.from_context()

        # Add a system prompt to ensure proper GAIA format
        system_prompt = """You are a general AI assistant. Answer the question concisely.
YOUR ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If asked for a number, don't use commas or units like $ or % unless specified.
If asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits as plain text unless specified otherwise.
Focus on brevity and correctness."""

        # Create input state with the human message and system prompt
        input_state = {
            "messages": [HumanMessage(content=question)],
            "configurable": {
                "thread_id": thread_id,
                "system_prompt": system_prompt,
                "model": config.model  # Ensure model is also set in the state
            }
        }

        # Process the question with our graph
        try:
            # Execute the graph and get the final state
            # Use invoke instead of stream to limit operations
            try:
                final_state = self.graph.invoke(input_state, config=self.config)
            except Exception as e:
                # If we hit a recursion error, try again with a higher limit
                print(f"Initial invocation failed: {str(e)}")
                # Use double the recursion limit as a fallback
                self.config["configurable"]["recursion_limit"] = config.recursion_limit * 2
                final_state = self.graph.invoke(input_state, config=self.config)

            # Extract the answer - either from gaia_answer or from the last message
            if "gaia_answer" in final_state:
                answer = final_state["gaia_answer"]
            else:
                messages = final_state.get("messages", [])
                answer = messages[-1].content if messages else "No answer generated."

            # Clean the answer to ensure proper GAIA format (remove any FINAL ANSWER prefix)
            if "FINAL ANSWER:" in answer:
                answer = answer.split("FINAL ANSWER:")[1].strip()

            print(f"Agent returning answer: {answer[:100]}...")
            return answer

        except Exception as e:
            error_msg = f"Error processing question: {str(e)}"
            print(error_msg)
            return error_msg


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetches all questions, runs the GaiaAgent on them, submits answers, and displays the results."""

    # --- Determine HF Space Runtime URL and Repo URL ---
    space_id = os.getenv("SPACE_ID")  # Get the SPACE_ID for sending a link to the code

    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # 1. Instantiate Agent
    try:
        agent = GaiaAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    # In the case of an app running as a Hugging Face space, this link points toward your codebase
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # 2. Fetch Questions
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.JSONDecodeError as e:
        # JSONDecodeError subclasses RequestException, so it must be caught first
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None

    # 3. Run the Agent
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            answer = agent(question_text)
            # Format answers according to API requirements - use submitted_answer as required
            answers_payload.append({
                "task_id": task_id,
                "submitted_answer": answer
            })
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Answer": answer
            })
        except Exception as e:
            print(f"Error running agent on task {task_id}: {e}")
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Answer": f"AGENT ERROR: {e}"
            })

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # 4. Prepare Submission
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

    # 5. Submit
    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        status_message = f"Submission Failed: {error_detail}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except Exception as e:
        status_message = f"An unexpected error occurred during submission: {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df


# Function to test a single random question
def test_random_question():
    """Fetch a random question from the API and run the agent on it."""
    api_url = DEFAULT_API_URL
    random_question_url = f"{api_url}/random-question"

    try:
        # Fetch a random question
        response = requests.get(random_question_url, timeout=15)
        response.raise_for_status()
        question_data = response.json()

        if not question_data:
            return "Error: Received empty response from random question endpoint.", None

        task_id = question_data.get("task_id")
        question_text = question_data.get("question")

        if not task_id or not question_text:
            return "Error: Invalid question format received.", None

        # Initialize agent and get answer
        agent = GaiaAgent()
        answer = agent(question_text)

        # Return results
        result = {
            "Task ID": task_id,
            "Question": question_text,
            "Answer": answer
        }

        return "Test completed successfully.", result

    except Exception as e:
        return f"Error testing random question: {str(e)}", None


# --- Build Gradio Interface using Blocks ---
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Benchmark Agent Evaluation")
    gr.Markdown(
        """
        **Instructions:**

        1. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the agent, submit answers, and see the score.
        3. Alternatively, click 'Test on Random Question' to test the agent on a single random question.

        ---
        **Note:** Running the agent on all questions may take some time. Please be patient while the agent processes all the questions.
        """
    )

    gr.LoginButton()

    with gr.Tabs():
        with gr.TabItem("Full Evaluation"):
            run_button = gr.Button("Run Evaluation & Submit All Answers")
            status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
            results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

            run_button.click(
                fn=run_and_submit_all,
                outputs=[status_output, results_table]
            )

        with gr.TabItem("Test Single Question"):
            test_button = gr.Button("Test on Random Question")
            test_status = gr.Textbox(label="Test Status", lines=2, interactive=False)
            test_result = gr.JSON(label="Question and Answer")

            test_button.click(
                fn=test_random_question,
                outputs=[test_status, test_result]
            )


if __name__ == "__main__":
    print("\n" + "-"*30 + " App Starting " + "-"*30)
    # Check for SPACE_HOST and SPACE_ID at startup for information
    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")  # Get SPACE_ID at startup

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        print(f"   Runtime URL should be: https://{space_host_startup}.hf.space")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:  # Print repo URLs if SPACE_ID is found
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")

    print("-"*(60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for GAIA Agent Evaluation...")
    demo.launch(debug=True, share=False)
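
The `GaiaAgent` class above can also be exercised outside the Gradio UI. A minimal local smoke test — a sketch, assuming the `react_agent` package is importable and the relevant API keys (OpenAI/Anthropic/Google models, plus Tavily for search) are set in the environment:

# Hypothetical local check, not part of the Space itself.
from app import GaiaAgent
from langgraph.checkpoint.memory import MemorySaver

agent = GaiaAgent(model_name="openai/gpt-4o-mini", checkpointer=MemorySaver())
print(agent("What is the capital of France?"))  # expected: a bare GAIA-style answer such as "Paris"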
configuration.py
ADDED
@@ -0,0 +1,166 @@
"""Define the configurable parameters for the agent supervisor system."""

from __future__ import annotations

from dataclasses import dataclass, field, fields
from typing import Annotated

from langchain_core.runnables import ensure_config
from langgraph.config import get_config

from react_agent import prompts


@dataclass(kw_only=True)
class Configuration:
    """The configuration for the agent supervisor system."""

    # Supervisor configuration
    supervisor_prompt: str = field(
        default=prompts.SUPERVISOR_PROMPT,
        metadata={
            "description": "The system prompt for the supervisor agent. "
            "This prompt guides how the supervisor delegates tasks to worker agents."
        },
    )

    # Planner configuration
    planner_prompt: str = field(
        default=prompts.PLANNER_PROMPT,
        metadata={
            "description": "The system prompt for the planner agent. "
            "This prompt guides how the planner creates structured plans."
        },
    )

    # Critic configuration
    critic_prompt: str = field(
        default=prompts.CRITIC_PROMPT,
        metadata={
            "description": "The system prompt for the critic agent. "
            "This prompt guides how the critic evaluates answers."
        },
    )

    # Worker agents configuration
    researcher_prompt: str = field(
        default=prompts.RESEARCHER_PROMPT,
        metadata={
            "description": "The system prompt for the researcher agent. "
            "This prompt defines the researcher's capabilities and limitations."
        },
    )

    coder_prompt: str = field(
        default=prompts.CODER_PROMPT,
        metadata={
            "description": "The system prompt for the coder agent. "
            "This prompt defines the coder's capabilities and approach to programming tasks."
        },
    )

    # Shared configuration
    system_prompt: str = field(
        default=prompts.SYSTEM_PROMPT,
        metadata={
            "description": "Legacy system prompt for backward compatibility. "
            "This prompt is used when running the agent in non-supervisor mode."
        },
    )

    # LLM Configuration - Default model for backward compatibility
    model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="openai/gpt-4o-mini",
        metadata={
            "description": "The default large language model used by the agents (provider/model_name)."
        },
    )

    # Model for the researcher (information gathering)
    researcher_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="openai/gpt-4o-mini",
        metadata={
            "description": "The model used by the researcher agent for gathering information (provider/model_name)."
        },
    )

    # Model for the coder (code execution) - use Claude Sonnet
    coder_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="anthropic/claude-3-5-sonnet-20240620",
        metadata={
            "description": "The model used by the coder agent for programming tasks (provider/model_name)."
        },
    )

    # Model for lightweight reasoning tasks (planner, supervisor, critic)
    planner_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="google_genai/gemini-1.5-flash",
        metadata={
            "description": "The lightweight reasoning model used by the planner, supervisor, and critic (provider/model_name)."
        },
    )

    # Supervisor defaults to the same lightweight model as the planner
    supervisor_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="google_genai/gemini-1.5-flash",
        metadata={
            "description": "The model used by the supervisor for routing (provider/model_name)."
        },
    )

    critic_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="openai/gpt-4o-mini",
        metadata={
            "description": "The model used by the critic for evaluation (provider/model_name)."
        },
    )

    # Model for final answer generation - using Claude for precise formatting
    final_answer_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
        default="anthropic/claude-3-5-sonnet-20240620",
        metadata={
            "description": "The model used for generating the final answers in GAIA benchmark format (provider/model_name)."
        },
    )

    # Tool Configuration
    max_search_results: int = field(
        default=5,
        metadata={
            "description": "The maximum number of search results to return."
        },
    )

    # Execution Configuration
    recursion_limit: int = field(
        default=50,
        metadata={
            "description": "Maximum number of recursion steps allowed in the LangGraph execution."
        },
    )

    max_iterations: int = field(
        default=12,
        metadata={
            "description": "Maximum number of iterations allowed to prevent infinite loops."
        },
    )

    allow_agent_to_extract_answers: bool = field(
        default=True,
        metadata={
            "description": "Whether to allow the agent to extract answers from context when formatting fails."
        },
    )

    @classmethod
    def from_context(cls) -> Configuration:
        """Create a Configuration instance from a RunnableConfig object."""
        try:
            config = get_config()
        except RuntimeError:
            config = None
        config = ensure_config(config)
        configurable = config.get("configurable") or {}
        _fields = {f.name for f in fields(cls) if f.init}
        return cls(**{k: v for k, v in configurable.items() if k in _fields})
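
Note how `from_context` resolves values: outside a LangGraph run, `get_config()` raises `RuntimeError`, so the dataclass defaults apply; inside a run, any matching key under `configurable` overrides the corresponding field, and unknown keys are dropped. A small sketch of that behavior, assuming the package is importable:

from react_agent.configuration import Configuration

# With no ambient LangGraph config, the defaults above are returned.
cfg = Configuration.from_context()
assert cfg.model == "openai/gpt-4o-mini"
assert cfg.recursion_limit == 50

# During graph execution, values passed as
#   graph.invoke(state, config={"configurable": {"model": "anthropic/claude-3-5-sonnet-20240620"}})
# would override the matching fields instead.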
graph.py
ADDED
@@ -0,0 +1,574 @@
"""Define an Agent Supervisor graph with specialized worker agents.

The supervisor routes tasks to specialized agents based on the query type.
"""

from typing import Dict, List, Literal, Optional, Union, Type, cast

from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
# Import adjusted for compatibility
from langgraph.prebuilt import create_react_agent  # Try original import path first
from langgraph.types import Command

from react_agent.configuration import Configuration
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router, Plan, PlanStep, CriticVerdict
from react_agent.tools import TOOLS, tavily_tool, python_repl_tool
from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
from react_agent import prompts
from react_agent.supervisor_node import supervisor_node


# Compile-time type definitions
SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]
WorkerDestination = Literal["supervisor"]


# Helper function to check if a message is from a user
def is_user_message(message):
    """Check if a message is from a user regardless of message format."""
    if isinstance(message, dict):
        return message.get("role") == "user"
    elif isinstance(message, HumanMessage):
        return True
    return False


# Helper function to get message content
def get_message_content(message):
    """Extract content from a message regardless of format."""
    if isinstance(message, dict):
        return message.get("content", "")
    elif hasattr(message, "content"):
        return message.content
    return ""


# --- Planner node ---------------------------------------------------------

def planner_node(state: State) -> Command[WorkerDestination]:
    """Planning LLM that creates a step-by-step execution plan.

    Args:
        state: The current state with messages

    Returns:
        Command to update the state with a plan
    """
    configuration = Configuration.from_context()
    # Use the specific planner model
    planner_llm = load_chat_model(configuration.planner_model)

    # Track steps
    steps_taken = state.get("steps_taken", 0)
    steps_taken += 1

    # Get the original user question (the latest user message)
    user_messages = [m for m in state["messages"] if is_user_message(m)]
    original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"

    # Create a chat prompt template with proper formatting
    planner_prompt_template = ChatPromptTemplate.from_messages([
        ("system", prompts.PLANNER_PROMPT),
        ("user", "{question}")
    ])

    # Format the prompt with the necessary variables
    formatted_messages = planner_prompt_template.format_messages(
        question=original_question,
        system_time=format_system_prompt("{system_time}"),
        workers=", ".join(WORKERS),
        worker_options=", ".join([f'"{w}"' for w in WORKERS]),
        example_worker_1=WORKERS[0] if WORKERS else "researcher",
        example_worker_2=WORKERS[1] if len(WORKERS) > 1 else "coder"
    )

    # Get structured output from the planner model
    plan = planner_llm.with_structured_output(Plan).invoke(formatted_messages)

    # Return with updated state
    return Command(
        goto="supervisor",
        update={
            "plan": plan,
            "current_step_index": 0,
            # Add a message to show the plan was created
            "messages": [
                HumanMessage(
                    content=f"Created plan with {len(plan['steps'])} steps",
                    name="planner"
                )
            ],
            "steps_taken": steps_taken
        }
    )


# --- Final Answer node -----------------------------------------------------

def final_answer_node(state: State) -> Command[Literal["__end__"]]:
    """Generate a final answer based on gathered information.

    Args:
        state: The current state with messages and context

    Returns:
        Command with final answer
    """
    configuration = Configuration.from_context()

    # Track steps
    steps_taken = state.get("steps_taken", 0)
    steps_taken += 1

    # Check if we've exhausted retries and already have a draft answer
    retry_exhausted = state.get("retry_exhausted", False)
    draft_answer = state.get("draft_answer")

    # Variable to store the final answer
    gaia_answer = ""

    if retry_exhausted and draft_answer and draft_answer.startswith("FINAL ANSWER:"):
        # If supervisor already provided a properly formatted answer after exhausting retries,
        # use it directly without calling the model again
        import re
        final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", draft_answer, re.IGNORECASE)
        if final_answer_match:
            gaia_answer = final_answer_match.group(1).strip()
        else:
            gaia_answer = "unknown"
    else:
        # Use the specific final answer model
        final_llm = load_chat_model(configuration.final_answer_model)

        # Get the original user question (the latest user message)
        user_messages = [m for m in state["messages"] if is_user_message(m)]
        original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"

        # Check if we already have a draft answer from supervisor
        if draft_answer and draft_answer.startswith("FINAL ANSWER:"):
            # If supervisor already provided a properly formatted answer, use it directly
            raw_answer = draft_answer
        else:
            # Get the context and worker results
            context = state.get("context", {})
            worker_results = state.get("worker_results", {})

            # Compose a prompt for the final answer using the GAIA-specific format
            final_prompt = ChatPromptTemplate.from_messages([
                ("system", prompts.FINAL_ANSWER_PROMPT),
                ("user", prompts.FINAL_ANSWER_USER_PROMPT)
            ])

            # Format the context information more effectively
            context_list = []
            # First include researcher context as it provides background
            if "researcher" in context:
                context_list.append(f"Research information: {context['researcher']}")

            # Then include coder results which are typically calculations
            if "coder" in context:
                context_list.append(f"Calculation results: {context['coder']}")

            # Add any other workers
            for worker, content in context.items():
                if worker not in ["researcher", "coder"]:
                    context_list.append(f"{worker.capitalize()}: {content}")

            # Get the final answer
            formatted_messages = final_prompt.format_messages(
                question=original_question,
                context="\n\n".join(context_list)
            )

            raw_answer = final_llm.invoke(formatted_messages).content

        # Extract the answer in GAIA format: "FINAL ANSWER: [x]"
        import re
        gaia_answer = raw_answer
        final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", raw_answer, re.IGNORECASE)
        if final_answer_match:
            gaia_answer = final_answer_match.group(1).strip()

    # Ensure answer is properly formatted - if we don't have a valid answer
    # but have sufficient context, try to extract directly
    if configuration.allow_agent_to_extract_answers and (not gaia_answer or gaia_answer.lower() in ["unknown", "insufficient information"]):
        context = state.get("context", {})
        from react_agent.supervisor_node import extract_best_answer_from_context
        extracted_answer = extract_best_answer_from_context(context)
        if extracted_answer != "unknown":
            gaia_answer = extracted_answer

    # Set status to "final_answer_generated" to indicate we're done
    return Command(
        goto=END,
        update={
            "messages": [
                AIMessage(
                    content=f"FINAL ANSWER: {gaia_answer}",
                    name="supervisor"
                )
            ],
            "next": "FINISH",  # Update next to indicate we're done
            "gaia_answer": gaia_answer,  # Store answer in GAIA-compatible format
            "submitted_answer": gaia_answer,  # Store as submitted_answer for GAIA benchmark
            "status": "final_answer_generated",  # Add status to indicate we're complete
            "steps_taken": steps_taken
        }
    )


# --- Critic node ----------------------------------------------------------

def critic_node(state: State) -> Command[Union[WorkerDestination, SupervisorDestinations]]:
    """Critic that evaluates if the answer fully satisfies the request.

    Args:
        state: The current state with messages and draft answer

    Returns:
        Command with evaluation verdict
    """
    configuration = Configuration.from_context()
    # Use the specific critic model
    critic_llm = load_chat_model(configuration.critic_model)

    # Track steps
    steps_taken = state.get("steps_taken", 0)
    steps_taken += 1

    # Get the original user question (the latest user message)
    user_messages = [m for m in state["messages"] if is_user_message(m)]
    original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"

    # Get the draft answer
    draft_answer = state.get("draft_answer", "No answer provided.")

    # Create a chat prompt template with proper formatting
    critic_prompt_template = ChatPromptTemplate.from_messages([
        ("system", prompts.CRITIC_PROMPT),
        ("user", prompts.CRITIC_USER_PROMPT)
    ])

    # Format the prompt with the necessary variables
    formatted_messages = critic_prompt_template.format_messages(
        question=original_question,
        answer=draft_answer,
        system_time=format_system_prompt("{system_time}"),
        correct_verdict=VERDICTS[0] if VERDICTS else "CORRECT",
        retry_verdict=VERDICTS[1] if len(VERDICTS) > 1 else "RETRY"
    )

    # Get structured output from the critic model
    verdict = critic_llm.with_structured_output(CriticVerdict).invoke(formatted_messages)

    # Add a message about the verdict
    if verdict["verdict"] == VERDICTS[0]:  # CORRECT
        verdict_message = "Answer is complete, accurate, and properly formatted for GAIA."
        goto = "final_answer"  # Go to final answer node if correct
    else:
        verdict_message = f"Answer needs improvement. Reason: {verdict.get('reason', 'Unknown')}"
        goto = "supervisor"

    # Return with updated state
    return Command(
        goto=goto,
        update={
            "critic_verdict": verdict,
            "messages": [
                HumanMessage(
                    content=verdict_message,
                    name="critic"
                )
            ],
            "steps_taken": steps_taken
        }
    )


# --- Worker agent factory -------------------------------------------------

def create_worker_node(worker_type: str):
    """Factory function to create a worker node of the specified type.

    Args:
        worker_type: The type of worker to create (must be in WORKERS)

    Returns:
        A function that processes requests for the specified worker type
    """
    if worker_type not in WORKERS:
        raise ValueError(f"Unknown worker type: {worker_type}")

    configuration = Configuration.from_context()

    # Select the appropriate model for each worker type
    if worker_type == "researcher":
        llm = load_chat_model(configuration.researcher_model)
        worker_prompt = prompts.RESEARCHER_PROMPT
        worker_tools = [tavily_tool]
    elif worker_type == "coder":
        llm = load_chat_model(configuration.coder_model)
        worker_prompt = prompts.CODER_PROMPT
        worker_tools = [python_repl_tool]
    else:
        # Default case
        llm = load_chat_model(configuration.model)
        worker_prompt = getattr(prompts, f"{worker_type.upper()}_PROMPT", prompts.SYSTEM_PROMPT)
        worker_tools = TOOLS

    # Create the agent
    worker_agent = create_react_agent(
        llm,
        tools=worker_tools,
        prompt=format_system_prompt(worker_prompt)
    )

    # Define node function
    def worker_node(state: State) -> Command[WorkerDestination]:
        """Process requests using the specified worker.

        Args:
            state: The current conversation state

        Returns:
            Command to return to supervisor with results
        """
        # Track steps
        steps_taken = state.get("steps_taken", 0)
        steps_taken += 1

        # Get the last message from the supervisor, which contains our task
        task_message = None
        if state.get("messages"):
            for msg in reversed(state["messages"]):
                if hasattr(msg, "name") and msg.name == "supervisor":
                    task_message = msg
                    break

        if not task_message:
            return Command(
                goto="supervisor",
                update={
                    "messages": [
                        HumanMessage(
                            content=f"Error: No task message found for {worker_type}",
                            name=worker_type
                        )
                    ],
                    "steps_taken": steps_taken
                }
            )

        # Create a new state with just the relevant messages for this worker
        # This prevents confusion from unrelated parts of the conversation
        agent_input = {
            "messages": [
                # Include the first user message for context
                state["messages"][0] if state["messages"] else HumanMessage(content="Help me"),
                # Include the task message
                task_message
            ]
        }

        # Invoke the agent with the clean input
        result = worker_agent.invoke(agent_input)

        # Extract the result from the agent response
        result_content = extract_worker_result(worker_type, result, state)

        # Store the worker's result in shared context
        context_update = state.get("context", {}).copy()
        context_update[worker_type] = result_content

        # Store in worker_results history
        worker_results = state.get("worker_results", {}).copy()
        if worker_type not in worker_results:
            worker_results[worker_type] = []
        worker_results[worker_type].append(result_content)

        # Increment the step index after worker completes
        current_step_index = state.get("current_step_index", 0)

        return Command(
            update={
                "messages": [
                    HumanMessage(content=result_content, name=worker_type)
                ],
                "current_step_index": current_step_index + 1,
                "context": context_update,
                "worker_results": worker_results,
                "steps_taken": steps_taken
            },
            goto="supervisor",
        )

    return worker_node


def extract_worker_result(worker_type: str, result: dict, state: State) -> str:
    """Extract a clean, useful result from the worker's output.

    This handles different response formats from different worker types.

    Args:
        worker_type: The type of worker (researcher or coder)
        result: The raw result from the worker agent
        state: The current state for context

    Returns:
        A cleaned string with the relevant result information
    """
    # Handle empty results
    if not result or "messages" not in result or not result["messages"]:
        return f"No output from {worker_type}"

    # Get the last message from the agent
    last_message = result["messages"][-1]

    # Default to extracting content directly
    if hasattr(last_message, "content") and last_message.content:
        result_content = last_message.content
    else:
        result_content = f"No content from {worker_type}"

    # Special handling based on worker type
    if worker_type == "coder":
        # For coder outputs, extract the actual result values from code execution
        if "```" in result_content:
            # Try to extract stdout from code execution
            import re
            stdout_match = re.search(r"Stdout:\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
            if stdout_match:
                # Extract the actual execution output, not just the code
                execution_result = stdout_match.group(1).strip()
                if execution_result:
                    # Check if this is just a simple number result
                    if re.match(r"^\d+(\.\d+)?$", execution_result):
                        return execution_result
                    else:
                        return f"Code executed with result: {execution_result}"

            # If we couldn't find stdout, try to extract output in a different way
            # Look for "Result:" or similar indicators
            result_match = re.search(r"(?:Result|Output|Answer):\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
            if result_match:
                return result_match.group(1).strip()

    elif worker_type == "researcher":
        # For researcher outputs, keep the full detailed response
        # but ensure it's well-formatted
        if len(result_content) > 800:
            # If too long, try to extract key sections
            # Look for summary or conclusion sections
            import re
            summary_match = re.search(r"(?:Summary|Conclusion|To summarize|In summary):(.*?)(?:\n\n|$)",
                                      result_content, re.IGNORECASE | re.DOTALL)
            if summary_match:
                return summary_match.group(1).strip()

    # If no special handling was triggered, return the content as is
    return result_content


# --- Graph assembly -------------------------------------------------------

def create_agent_supervisor_graph() -> StateGraph:
    """Create the agent supervisor graph with all nodes and edges.

    Returns:
        StateGraph builder, to be compiled by the caller
    """
    # Initialize the graph with our State type
    builder = StateGraph(State)

    # Add control nodes
    builder.add_node("planner", planner_node)
    builder.add_node("supervisor", supervisor_node)
    builder.add_node("critic", critic_node)
    builder.add_node("final_answer", final_answer_node)

    # Add worker nodes dynamically based on WORKERS list
    for worker_type in WORKERS:
        builder.add_node(worker_type, create_worker_node(worker_type))

    # Define the workflow
    builder.add_edge(START, "supervisor")
    builder.add_edge("planner", "supervisor")
    builder.add_edge("critic", "supervisor")
    builder.add_edge("critic", "final_answer")  # Add edge from critic to final_answer
    builder.add_edge("final_answer", END)  # Final answer node goes to END
    builder.add_edge("supervisor", END)  # Allow the supervisor to end the workflow

    # Connect all workers to supervisor
    for worker_type in WORKERS:
        builder.add_edge(worker_type, "supervisor")

    # Return the builder, not a compiled graph
    # This allows the caller to compile with a checkpointer
    return builder


# --- Graph instantiation (with flexible checkpointing) -----------------------------

def get_compiled_graph(checkpointer=None):
    """Get a compiled graph with optional checkpointer.

    Args:
        checkpointer: Optional checkpointer for persistence

    Returns:
        Compiled StateGraph ready for execution
    """
    # Get configuration
    configuration = Configuration.from_context()

    builder = create_agent_supervisor_graph()

    # Define termination condition function to prevent loops
    def should_end(state):
        """Determine if the graph should terminate."""
        # End if status is set to final_answer_generated
        if state.get("status") == "final_answer_generated":
            return True

        # End if retry_exhausted flag is set and we've gone through final_answer
        if state.get("retry_exhausted") and state.get("gaia_answer"):
            return True

        # End if we've hit the maximum recursion limit defined by LangGraph
        steps_taken = state.get("steps_taken", 0)
        if steps_taken >= configuration.recursion_limit - 5:  # Leave buffer
            return True

        return False

    # Define step counter for tracking step count
    def count_steps(state):
        """Count steps to prevent infinite loops."""
        steps_taken = state.get("steps_taken", 0)
        return {"steps_taken": steps_taken + 1}

    # Compile the graph (don't use add_state_transform, which isn't available)
    if checkpointer:
        graph = builder.compile(
            checkpointer=checkpointer,
            name="Structured Reasoning Loop"
        )
    else:
        graph = builder.compile(
            name="Structured Reasoning Loop"
        )

    # Configure the graph with recursion limit and max iterations
    graph = graph.with_config({
        "recursion_limit": configuration.recursion_limit,
        "max_iterations": configuration.max_iterations
    })

    return graph


# Initialize a default non-checkpointed graph (for backward compatibility)
graph = get_compiled_graph()
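
For reference, the compiled graph can be driven directly, mirroring what `GaiaAgent` does in app.py. A minimal sketch, assuming the required model and Tavily API keys are available in the environment:

from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver

from react_agent.graph import get_compiled_graph

app = get_compiled_graph(checkpointer=MemorySaver())
final_state = app.invoke(
    {"messages": [HumanMessage(content="What is 17 * 24?")]},
    config={"configurable": {"thread_id": "demo-thread"}},
)
print(final_state.get("gaia_answer"))  # set by final_answer_node, e.g. "408"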
prompts.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""System prompts used by the agent supervisor and worker agents."""
|
2 |
+
|
3 |
+
from react_agent.state import WORKERS, VERDICTS
|
4 |
+
|
5 |
+
# --- Supervisor prompt -----------------------------------------------------
|
6 |
+
|
7 |
+
SUPERVISOR_PROMPT = """You are a supervisor tasked with managing a conversation between the \
|
8 |
+
following workers: {workers}. Given the following user request, \
|
9 |
+
respond with the worker to act next. Each worker will perform a \
|
10 |
+
task and respond with their results and status. When finished, \
|
11 |
+
respond with FINISH.
|
12 |
+
|
13 |
+
System time: {system_time}"""
|
14 |
+
|
15 |
+
# --- Planner prompt -------------------------------------------------------
|
16 |
+
|
17 |
+
PLANNER_PROMPT = """**Role**: You are a Planner node in a LangGraph supervisor workflow
|
18 |
+
**Goal**: Given the user's original request, create a concise, focused plan that directly answers the question.
|
19 |
+
|
20 |
+
Requirements:
|
21 |
+
1. Output only a JSON object with one key `steps`, whose value is an **ordered list** of at least 1 and at most 3 objects.
|
22 |
+
Each object has:
|
23 |
+
• `worker` – one of: {worker_options}
|
24 |
+
• `instruction` – ≤ 20 words telling that worker what to do
|
25 |
+
|
26 |
+
2. Your plan MUST:
|
27 |
+
• Directly address the user's specific question
|
28 |
+
• Include at least one step (never return empty steps)
|
29 |
+
• Be focused on finding the exact answer requested, not the process of answering
|
30 |
+
• Use researcher for information gathering
|
31 |
+
• Use coder for calculations or data analysis if needed
|
32 |
+
|
33 |
+
3. Common tasks:
|
34 |
+
• For factual questions: use researcher to find the specific fact
|
35 |
+
• For calculations: use researcher to find data, then coder to calculate
|
36 |
+
• For multiple-part questions: break into steps with the right workers
|
37 |
+
• Ensure your last step gets the exact answer in the format requested
|
38 |
+
|
39 |
+
Example:
|
40 |
+
```
|
41 |
+
{{
|
42 |
+
"steps": [
|
43 |
+
{{"worker": "{example_worker_1}", "instruction": "Find inflation rate in 2023"}},
|
44 |
+
{{"worker": "{example_worker_2}", "instruction": "Compute average of 2019–2023 rates"}}
|
45 |
+
]
|
46 |
+
}}
|
47 |
+
```
|
48 |
+
|
49 |
+
System time: {system_time}"""
|
50 |
+
|
51 |
+
# --- Critic prompt --------------------------------------------------------
|
52 |
+
|
53 |
+
CRITIC_PROMPT = """**Role**: You are a Critic node specializing in GAIA benchmark format validation
|
54 |
+
**Goal**: Strictly check if the answer follows GAIA format requirements
|
55 |
+
|
56 |
+
Requirements:
|
57 |
+
1. You will check if the answer:
|
58 |
+
• Addresses all parts of the user's question correctly
|
59 |
+
• Follows the EXACT required GAIA format: "FINAL ANSWER: [concise response]"
|
60 |
+
• Contains ONLY the essential information in the [concise response]:
|
61 |
+
- A single number (no commas, no units like $ or % unless specified)
|
62 |
+
- A single word or very short phrase
|
63 |
+
- A comma-separated list of numbers or strings
|
64 |
+
• Has NO explanations, reasoning, or extra text
|
65 |
+
• For strings: no articles or abbreviations
|
66 |
+
• For numbers: digits only without commas
|
67 |
+
|
68 |
+
2. If the answer is CORRECT, respond ONLY with this exact JSON:
|
69 |
+
• `{{"verdict":"{correct_verdict}"}}`
|
70 |
+
|
71 |
+
3. If ANY requirement is NOT MET, respond with this JSON including a SPECIFIC reason:
|
72 |
+
• `{{"verdict":"{retry_verdict}","reason":"<specific format issue>"}}`
|
73 |
+
• IMPORTANT: You MUST provide a substantive reason that clearly explains what's wrong
|
74 |
+
• NEVER leave the reason empty or only containing quotes
|
75 |
+
|
76 |
+
4. Common reason examples:
|
77 |
+
• "Answer not formatted as 'FINAL ANSWER: [response]'"
|
78 |
+
• "Answer contains explanations instead of just the concise response"
|
79 |
+
• "Answer does not address the question about [specific topic]"
|
80 |
+
• "Answer contains units when it should just be a number"
|
81 |
+
|
82 |
+
DO NOT include any text before or after the JSON. Your complete response must be valid JSON that can be parsed.
|
83 |
+
|
84 |
+
System time: {system_time}"""
|
85 |
+
|
86 |
+
# --- Critic user prompt ---------------------------------------------------
|
87 |
+
|
88 |
+
CRITIC_USER_PROMPT = """Original question: {question}
|
89 |
+
|
90 |
+
Draft answer: {answer}
|
91 |
+
|
92 |
+
Check if the draft answer follows GAIA format requirements:
|
93 |
+
1. Format must be exactly "FINAL ANSWER: [concise response]"
|
94 |
+
2. [concise response] must ONLY be:
|
95 |
+
- A single number (no commas or units unless specified)
|
96 |
+
- A single word or very short phrase
|
97 |
+
- A comma-separated list of numbers or strings
|
98 |
+
3. NO explanations or additional text is allowed
|
99 |
+
4. Strings should not have articles or abbreviations
|
100 |
+
5. Numbers should be in digits without commas
|
101 |
+
|
102 |
+
Does the answer meet these requirements and correctly answer the question?"""
|
103 |
+
|
104 |
+
# --- Final Answer format for GAIA benchmark -------------------------------
|
105 |
+
|
106 |
+
FINAL_ANSWER_PROMPT = """You are a response formatter for a GAIA benchmark question.

Your only job is to format the final answer in the exact format required: "FINAL ANSWER: [concise response]"

Requirements for [concise response]:
1. Response must ONLY be one of these formats:
   - A single number (no commas, no units like $ or % unless specified)
   - A single word or very short phrase
   - A comma-separated list of numbers or strings
2. DO NOT include any explanations, reasoning, or extra text
3. For strings, don't use articles or abbreviations unless specified
4. For numbers, write digits (not spelled out) without commas
5. The response should be as concise as possible while being correct

Original question: {question}

Information available:
{context}

After reviewing the information, extract just the essential answer and output ONLY:
FINAL ANSWER: [your concise response]
"""

# --- Final Answer user prompt ---------------------------------------------

FINAL_ANSWER_USER_PROMPT = """Original question: {question}

Information available:
{context}

Remember to output ONLY 'FINAL ANSWER: [your concise response]' with no explanations."""
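Downstream code still has to strip the prefix before submitting to the scoring API. A minimal sketch (`extract_final_answer` is a hypothetical helper, not part of the upload):

```python
import re

def extract_final_answer(text: str) -> str:
    """Pull the concise response out of a 'FINAL ANSWER: ...' reply."""
    match = re.search(r"FINAL ANSWER:\s*(.+)", text, re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()

print(extract_final_answer("FINAL ANSWER: 42"))  # -> "42"
```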

# --- Worker agent prompts -------------------------------------------------

RESEARCHER_PROMPT = """You are a research specialist focused on finding information and providing context.

Your key responsibilities:
1. Search for accurate, up-to-date information on any topic
2. Provide factual knowledge about products, concepts, and terminology
3. Explain real-world contexts and background information
4. Identify relevant parameters and variables needed for calculations
5. Present information clearly with proper citations

DO NOT perform complex calculations or coding tasks - these will be handled by the coder agent.
You MAY provide simple arithmetic or basic formulas to illustrate concepts.

Always return information in a structured, organized format that will be useful for the next steps.

System time: {system_time}
"""

CODER_PROMPT = """You are a computational specialist focused on calculations, coding, and data analysis.

Your key responsibilities:
1. Write and execute Python code for calculations and data manipulation
2. Perform precise numerical analyses based on inputs from the researcher
3. Format results clearly with appropriate units and precision
4. Use markdown to structure your response with headings and bullet points
5. Verify calculations through multiple methods when possible

Important:
1. Always include both your calculation process AND final result values
2. Always clearly state your assumptions when making calculations
3. Format numerical results with appropriate precision and units
4. When receiving data from the researcher, acknowledge and build upon it directly
5. If a calculation involves multiple steps or cases, organize them with headings

System time: {system_time}
"""

# --- Legacy system prompt (kept for backward compatibility) ---------------

SYSTEM_PROMPT = """You are a helpful AI assistant.

System time: {system_time}"""
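All of these templates defer `{system_time}` and the worker/verdict fields to `format_system_prompt` in utils.py (shown later in this upload); templates with per-request fields such as `{question}` and `{context}` are presumably filled with `str.format` at call time. A usage sketch, assuming the react_agent package is importable:

```python
from react_agent import prompts
from react_agent.utils import format_system_prompt

# Renders {system_time} (and, for other templates, {example_worker_1},
# {correct_verdict}, etc.) from the shared WORKERS/VERDICTS lists.
system_prompt = format_system_prompt(prompts.RESEARCHER_PROMPT)
print(system_prompt[:200])
```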
state.py
ADDED
@@ -0,0 +1,72 @@
"""Define the state structures for the agent supervisor."""

from __future__ import annotations

from typing import Any, Dict, List, Literal, Optional

from langgraph.graph import MessagesState
from typing_extensions import TypedDict


# --- Constants and shared definitions ---------------------------------------

# Define worker types (specialized agents that perform tasks)
WORKERS = ["researcher", "coder"]

# Define all member types (including control nodes)
MEMBERS = WORKERS + ["planner", "critic", "supervisor"]

# Define status/routing options
VERDICTS = ["CORRECT", "RETRY"]
ROUTING = ["FINISH"] + WORKERS
OPTIONS = ROUTING + VERDICTS


# --- Router for supervisor decisions ---------------------------------------

class Router(TypedDict):
    """Determines which worker to route to next or if the task is complete.

    The supervisor returns this structure to navigate the workflow.
    Valid values are defined in the ROUTING list.
    """
    # Unpacking a runtime list into Literal requires Python 3.11+.
    next: Literal[*ROUTING]


# --- Plan structure for the Planner node -----------------------------------

class PlanStep(TypedDict):
    """A single step in the plan created by the Planner."""
    worker: Literal[*WORKERS]
    instruction: str


class Plan(TypedDict):
    """The complete plan produced by the Planner node."""
    steps: List[PlanStep]


# --- Critic verdict structure ----------------------------------------------

class CriticVerdict(TypedDict):
    """The verdict from the Critic on whether the answer is satisfactory."""
    verdict: Literal[*VERDICTS]
    reason: Optional[str]


# --- State for the agent supervisor ----------------------------------------

class State(MessagesState):
    """State for the agent supervisor workflow.

    Extends MessagesState, which provides message history tracking.
    Adds fields to track routing information, the plan, and the critic verdict.
    """
    # NOTE: MessagesState is a TypedDict, so these fields cannot carry runtime
    # defaults; they are populated by the graph nodes as the workflow runs.
    next: str
    plan: Optional[Plan]
    current_step_index: Optional[int]
    draft_answer: Optional[str]
    critic_verdict: Optional[CriticVerdict]
    context: Dict[str, Any]  # Shared context accessible to all agents
    worker_results: Dict[str, List[str]]  # Results collected from each worker
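Since State builds on a TypedDict, its optional fields carry no runtime defaults; a sketch of a minimal initial state for one run (field names as defined above, question text illustrative):

```python
from langchain_core.messages import HumanMessage

# Minimal initial state for one run; optional fields are simply omitted
# and populated by the planner/workers as the graph executes.
initial_state = {
    "messages": [HumanMessage(content="What was the average inflation rate 2019-2023?")],
    "next": "planner",
    "context": {},
    "worker_results": {},
}
```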
supervisor_node.py
ADDED
@@ -0,0 +1,406 @@
"""Supervisor node implementation for the agent supervisor system."""

import re
from typing import Literal, cast

from langchain_core.messages import HumanMessage
from langgraph.types import Command

from react_agent.configuration import Configuration
from react_agent.state import WORKERS, VERDICTS, State


# Compile-time type definitions
SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]


def supervisor_node(state: State) -> Command[SupervisorDestinations]:
    """Supervising LLM that decides which specialized agent should act next.

    Args:
        state: The current state with messages

    Returns:
        Command with routing information
    """
    # Get configuration to use supervisor_model
    configuration = Configuration.from_context()

    # Track steps to prevent infinite loops
    steps_taken = state.get("steps_taken", 0) + 1
    state_updates = {"steps_taken": steps_taken}

    # Check if we've hit our step limit (with a buffer of 5 steps)
    if steps_taken >= configuration.recursion_limit - 5:
        # Extract the best answer we have from context if possible
        context = state.get("context", {})
        answer = extract_best_answer_from_context(context)

        return Command(
            goto="final_answer",
            update={
                "messages": [
                    HumanMessage(
                        content=f"Maximum steps ({steps_taken}) reached. Extracting best answer from available information.",
                        name="supervisor"
                    )
                ],
                "draft_answer": f"FINAL ANSWER: {answer}",
                "retry_exhausted": True,  # Flag to indicate we've exhausted retries
                "steps_taken": steps_taken
            }
        )

    # Safety check - prevent infinite loops by forcing termination after too many retries
    retry_count = state.get("retry_count", 0)
    max_retries = 2  # Maximum number of allowed retries

    if retry_count > max_retries:
        # Extract the best answer we have from context if possible
        context = state.get("context", {})
        answer = extract_best_answer_from_context(context)

        return Command(
            goto="final_answer",
            update={
                "messages": [
                    HumanMessage(
                        content=f"Maximum retries ({max_retries}) reached. Extracting best answer from available information.",
                        name="supervisor"
                    )
                ],
                "draft_answer": f"FINAL ANSWER: {answer}",
                "retry_exhausted": True,  # Flag to indicate we've exhausted retries
                "steps_taken": steps_taken
            }
        )

    # Check if we need a plan
    if not state.get("plan"):
        return Command(
            goto="planner",
            update={
                **state_updates
            }
        )

    # Validate that the plan has at least one step
    plan = state.get("plan")
    if not plan.get("steps"):
        # Plan has no steps, go back to planner with explicit instructions
        return Command(
            goto="planner",
            update={
                "messages": [
                    HumanMessage(
                        content="Previous plan had 0 steps. Please create a plan with at least 1 step to solve the user's question.",
                        name="supervisor"
                    )
                ],
                "plan": None,
                **state_updates
            }
        )

    # Check if we have a critic verdict that requires replanning
    critic_verdict = state.get("critic_verdict")
    if critic_verdict:
        if critic_verdict.get("verdict") == VERDICTS[0]:  # CORRECT
            # Final answer is approved, navigate to the final_answer node,
            # which generates a polished response before ending.
            return Command(
                goto="final_answer",
                update={
                    "messages": [
                        HumanMessage(
                            content="Answer approved by critic. Generating final response.",
                            name="supervisor"
                        )
                    ]
                }
            )
        elif critic_verdict.get("verdict") == VERDICTS[1]:  # RETRY
            # IMPORTANT: Get the current retry count BEFORE incrementing
            current_retry_count = state.get("retry_count", 0)

            # Check if we're at the maximum allowed retries
            if current_retry_count >= max_retries:
                # Extract best answer and go to final_answer
                context = state.get("context", {})
                answer = extract_best_answer_from_context(context)

                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            HumanMessage(
                                content=f"Maximum retries ({max_retries}) reached. Proceeding with best available answer.",
                                name="supervisor"
                            )
                        ],
                        "draft_answer": f"FINAL ANSWER: {answer}",
                        "retry_exhausted": True  # Flag to indicate we've exhausted retries
                    }
                )

            # Reset the plan but KEEP the context from previous iterations
            context = state.get("context", {})
            worker_results = state.get("worker_results", {})

            # Get the critic's reason for rejection, if any; treat an empty
            # reason (or one that is only quote characters) as missing.
            reason = critic_verdict.get("reason", "")
            if not reason or not reason.strip().strip('"'):
                reason = "Answer did not meet format requirements"

            # Check if this is a formatting issue (entries are kept lowercase
            # because they are matched against reason.lower())
            format_issues = [
                "format", "concise", "explanation", "not formatted",
                "instead of just", "contains explanations", "final answer"
            ]
            is_format_issue = any(issue in reason.lower() for issue in format_issues)

            # If we have enough information but the format is wrong, go directly to final answer
            has_sufficient_info = has_sufficient_information(state)

            if is_format_issue and has_sufficient_info:
                # We have information but formatting is wrong - skip planning and go to final answer
                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            HumanMessage(
                                content="We have sufficient information but formatting issues. Generating properly formatted answer.",
                                name="supervisor"
                            )
                        ],
                        "retry_count": current_retry_count + 1  # Still increment retry count
                    }
                )

            # Increment the retry counter
            next_retry_count = current_retry_count + 1

            return Command(
                goto="planner",
                update={
                    "plan": None,
                    "current_step_index": None,
                    "draft_answer": None,
                    "critic_verdict": None,
                    # Keep the context and worker_results
                    "context": context,
                    "worker_results": worker_results,
                    # Track retries - IMPORTANT: store the incremented count
                    "retry_count": next_retry_count,
                    # Add a message about the retry (using the INCREMENTED count)
                    "messages": [
                        HumanMessage(
                            content=f"Retrying with new plan (retry #{next_retry_count}). Reason: {reason}",
                            name="supervisor"
                        )
                    ]
                }
            )

    # Get the current step from the plan
    plan = state["plan"]
    # Treat an explicitly stored None the same as a missing index
    current_step_index = state.get("current_step_index") or 0

    # Check if we've completed all steps
    if current_step_index >= len(plan["steps"]):
        # Use context to compile the draft answer
        context = state.get("context", {})

        # Combine the most recent worker outputs as the draft answer
        worker_outputs = []
        for worker in WORKERS:
            if worker in context:
                worker_outputs.append(f"**{worker.title()}**: {context[worker]}")

        # Compile the draft answer from all worker outputs
        draft_content = "\n\n".join(worker_outputs)

        # Send to the critic for evaluation
        return Command(
            goto="critic",
            update={
                "draft_answer": draft_content,
                # Add a message about moving to evaluation
                "messages": [
                    HumanMessage(
                        content="All steps completed. Evaluating the answer.",
                        name="supervisor"
                    )
                ]
            }
        )

    # Get the current step
    current_step = plan["steps"][current_step_index]
    worker = current_step["worker"]
    instruction = current_step["instruction"]

    # Extract only the most relevant context for the current worker and task
    context_info = ""
    if state.get("context"):
        # Filter context by relevance to the current task
        relevant_context = {}

        # For the coder, extract numerical data and parameters from researcher
        if worker == "coder" and "researcher" in state["context"]:
            relevant_context["researcher"] = state["context"]["researcher"]

        # For the researcher, previous coder calculations might be relevant
        if worker == "researcher" and "coder" in state["context"]:
            # Only include numerical results from coder, not code snippets
            coder_content = state["context"]["coder"]
            if len(coder_content) < 100:  # Only short results are likely just numbers
                relevant_context["coder"] = coder_content

        # Format the relevant context items
        context_items = []
        for key, value in relevant_context.items():
            # Summarize if value is too long
            if len(value) > 200:
                # Find first sentence or up to 200 chars
                summary = value[:200]
                if '.' in summary:
                    summary = summary.split('.')[0] + '.'
                context_items.append(f"Previous {key} found: {summary}...")
            else:
                context_items.append(f"Previous {key} found: {value}")

        if context_items:
            context_info = "\n\nRelevant context: " + "\n".join(context_items)

    # Enhance the instruction with context
    enhanced_instruction = f"{instruction}{context_info}"

    # Add guidance based on worker type
    if worker == "coder":
        enhanced_instruction += "\nProvide both your calculation method AND the final result value."
    elif worker == "researcher":
        enhanced_instruction += "\nFocus on gathering factual information related to the task."

    # Add the instruction to the messages
    messages_update = [
        HumanMessage(
            content=f"Step {current_step_index + 1}: {enhanced_instruction}",
            name="supervisor"
        )
    ]

    # Cast worker to appropriate type to satisfy type checking
    worker_destination = cast(SupervisorDestinations, worker)

    # Move to the appropriate worker
    return Command(
        goto=worker_destination,
        update={
            "messages": messages_update,
            "next": worker,  # For backward compatibility
            **state_updates
        }
    )

def extract_best_answer_from_context(context):
    """Extract the best available answer from context.

    This is a generic function to extract answers from any type of question context.
    It progressively tries different strategies to find a suitable answer.

    Args:
        context: The state context containing worker outputs

    Returns:
        Best answer found or "unknown" if nothing suitable is found
    """
    answer = "unknown"

    # First check if the coder already provided a properly formatted answer
    if "coder" in context:
        coder_content = context["coder"]

        # Look for a "FINAL ANSWER: X" pattern in the coder output
        answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", coder_content, re.IGNORECASE)
        if answer_match:
            return answer_match.group(1).strip()

    # If no answer in coder output, check researcher content
    if "researcher" in context:
        researcher_content = context["researcher"]

        # Look for bulleted list items (a common pattern in researcher output)
        list_items = re.findall(r"[-•*]\s+([^:\n]+)", researcher_content)
        if list_items:
            # Format as comma-separated list
            return ",".join(item.strip() for item in list_items)

        # Look for emphasized/bold items which might be key information
        bold_items = re.findall(r"\*\*([^*]+)\*\*", researcher_content)
        if bold_items:
            # Join the important items as a comma-separated list
            processed_items = []
            for item in bold_items:
                # Remove common filler words and clean up the item
                clean_item = re.sub(r'(^|\s)(a|an|the|is|are|was|were|be|been)(\s|$)', ' ', item)
                clean_item = clean_item.strip()
                if clean_item and len(clean_item) < 30:  # Only include reasonably short items
                    processed_items.append(clean_item)

            if processed_items:
                return ",".join(processed_items)

    # If we still don't have an answer, try to extract common entities
    combined_content = ""
    for worker_type, content in context.items():
        combined_content += " " + content

    # Look for numbers in the content
    numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', combined_content)
    if numbers:
        answer = numbers[0]  # Use the first number found

    return answer

def has_sufficient_information(state):
    """Determine if we have enough information to generate a final answer.

    Args:
        state: The current conversation state

    Returns:
        Boolean indicating if we have sufficient information
    """
    context = state.get("context", {})

    # If we have both researcher and coder outputs, we likely have enough info
    if "researcher" in context and "coder" in context:
        return True

    # If we have a substantial researcher output, that might be enough
    if "researcher" in context and len(context["researcher"]) > 150:
        return True

    # If we have any worker output that contains lists or formatted data
    for worker, content in context.items():
        if content and (
            "- " in content or  # Bullet point
            "•" in content or   # Bullet point
            "*" in content or   # Emphasis or bullet
            ":" in content      # Definition or explanation
        ):
            return True

    return False
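A quick check of the extraction fallback defined above, under assumed worker outputs (the context values are illustrative):

```python
from react_agent.supervisor_node import extract_best_answer_from_context

context = {
    "researcher": "Key rates found:\n- 1.8\n- 1.2\n- 4.7",
    "coder": "FINAL ANSWER: 2.57",
}
# The coder's formatted answer is checked first, so it wins over the
# researcher's bullet list.
print(extract_best_answer_from_context(context))  # -> "2.57"
```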
tools.py
ADDED
@@ -0,0 +1,58 @@
"""This module provides tools for the agent supervisor.

It includes:
- Web Search: For general web results using Tavily.
- Python REPL: For executing Python code (Use with caution!).
"""

from typing import Annotated, Any, Callable, List

# Core Tools & Utilities
from langchain_core.tools import tool

# Experimental Tools (Use with caution)
from langchain_experimental.utilities import PythonREPL

# Use TavilySearchResults from langchain_community like in the notebook
from langchain_community.tools.tavily_search import TavilySearchResults

from react_agent.configuration import Configuration


# Create Tavily tool using configuration from context (more consistent approach)
def create_tavily_tool():
    """Create the Tavily search tool with configuration from context.

    Returns:
        Configured TavilySearchResults tool
    """
    configuration = Configuration.from_context()
    return TavilySearchResults(max_results=configuration.max_search_results)

# Initialize the tool
tavily_tool = create_tavily_tool()


# --- Python REPL Tool ---
# WARNING: Executes arbitrary Python code locally. Be extremely careful
# about exposing this tool, especially in production environments.
repl = PythonREPL()

@tool
def python_repl_tool(
    code: Annotated[str, "The python code to execute. Use print(...) to see output."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = repl.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    # Echo the executed code along with the captured stdout
    # (backticks need no escaping inside a Python f-string)
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return result_str


# --- Tool List ---

# The list of tools available to the agent supervisor.
TOOLS: List[Callable[..., Any]] = [tavily_tool, python_repl_tool]
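A smoke test of the REPL tool through the standard LangChain tool interface; note that importing this module instantiates the Tavily tool, which needs TAVILY_API_KEY set in the environment:

```python
from react_agent.tools import python_repl_tool

# @tool-decorated functions are invoked with a dict of their arguments.
print(python_repl_tool.invoke({"code": "print(2 + 2)"}))
```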
utils.py
ADDED
@@ -0,0 +1,76 @@
"""Utility & helper functions."""

import os
from datetime import UTC, datetime  # datetime.UTC requires Python 3.11+

from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage

from react_agent.state import WORKERS, MEMBERS, VERDICTS


# Load environment variables from .env file
load_dotenv()


def get_message_text(msg: BaseMessage) -> str:
    """Get the text content of a message."""
    content = msg.content
    if isinstance(content, str):
        return content
    elif isinstance(content, dict):
        return content.get("text", "")
    else:
        txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
        return "".join(txts).strip()


def format_system_prompt(prompt_template: str) -> str:
    """Format a system prompt template with current system time and available agents.

    Args:
        prompt_template: The prompt template to format

    Returns:
        The formatted prompt with system time and agent information
    """
    # Get example workers for templates
    example_worker_1 = WORKERS[0] if WORKERS else "researcher"
    example_worker_2 = WORKERS[1] if len(WORKERS) > 1 else "coder"

    # Get verdicts for templates
    correct_verdict = VERDICTS[0] if VERDICTS else "CORRECT"
    retry_verdict = VERDICTS[1] if len(VERDICTS) > 1 else "RETRY"

    return prompt_template.format(
        system_time=datetime.now(tz=UTC).isoformat(),
        workers=", ".join(WORKERS),
        members=", ".join(MEMBERS),
        worker_options=", ".join([f'"{w}"' for w in WORKERS]),
        example_worker_1=example_worker_1,
        example_worker_2=example_worker_2,
        correct_verdict=correct_verdict,
        retry_verdict=retry_verdict
    )


def load_chat_model(fully_specified_name: str) -> BaseChatModel:
    """Load a chat model from a fully specified name.

    Args:
        fully_specified_name (str): String in the format 'provider/model'.
    """
    provider, model = fully_specified_name.split("/", maxsplit=1)

    # Special handling for Google Genai models to ensure they're configured for async
    if provider == "google_genai":
        from langchain_google_genai import ChatGoogleGenerativeAI

        # Make sure we have the API key
        if not os.environ.get("GOOGLE_API_KEY"):
            raise ValueError("GOOGLE_API_KEY environment variable is required for google_genai models")

        return ChatGoogleGenerativeAI(model=model)
    else:
        return init_chat_model(model, model_provider=provider)
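A sketch tying the two helpers together; the provider/model string and the corresponding API key requirement (here OPENAI_API_KEY) are assumptions, not pinned by the upload:

```python
from react_agent import prompts
from react_agent.utils import format_system_prompt, load_chat_model

# Any init_chat_model-supported provider works, e.g. "openai/gpt-4o-mini"
# or "google_genai/gemini-1.5-flash".
model = load_chat_model("openai/gpt-4o-mini")
system = format_system_prompt(prompts.RESEARCHER_PROMPT)
reply = model.invoke([("system", system), ("user", "Who discovered penicillin?")])
print(reply.content)
```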