sims2k committed
Commit ac6a4ef · verified · 1 parent: 375e6bb

Upload 8 files

Files changed (8)
  1. app.py +375 -0
  2. configuration.py +166 -0
  3. graph.py +574 -0
  4. prompts.py +180 -0
  5. state.py +72 -0
  6. supervisor_node.py +406 -0
  7. tools.py +58 -0
  8. utils.py +76 -0
app.py ADDED
@@ -0,0 +1,375 @@
+ """Web application for the Agent Supervisor with GAIA benchmark integration.
+
+ This module provides a Gradio web interface for interacting with the Agent Supervisor
+ and evaluating it against the GAIA benchmark.
+ """
+
+ import os
+ import json
+ import uuid
+ import asyncio
+ import requests
+ import pandas as pd
+ import gradio as gr
+
+ from typing import Dict, List, Optional
+ from langchain_core.messages import HumanMessage
+ from langgraph.checkpoint.memory import MemorySaver
+
+ from react_agent.graph import create_agent_supervisor_graph, get_compiled_graph
+
+ # --- Constants ---
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
+
+ class GaiaAgent:
+     """Agent implementation for the GAIA benchmark using the LangGraph supervisor."""
+
+     def __init__(self, model_name=None, checkpointer=None):
+         """Initialize the GAIA agent with the LangGraph architecture.
+
+         Args:
+             model_name: Optional model name to override the default
+             checkpointer: Optional checkpointer for persistence
+         """
+         print("Initializing GaiaAgent...")
+
+         # Import the Configuration class
+         from react_agent.configuration import Configuration
+
+         # Get configuration
+         config = Configuration.from_context()
+         default_model = config.model
+
+         # If no checkpointer is provided, create a default one - MemorySaver avoids SQLite thread issues
+         if checkpointer is None:
+             checkpointer = MemorySaver()
+             print("Using in-memory checkpointer to avoid thread safety issues")
+
+         # Create and compile the graph
+         self.graph = get_compiled_graph(checkpointer=checkpointer)
+
+         # Configure the agent using values from Configuration
+         self.config = {
+             "configurable": {
+                 # Use the configuration model, or override if provided
+                 "model": model_name if model_name else default_model,
+                 # Specific models for each role from Configuration
+                 "researcher_model": config.researcher_model,
+                 "coder_model": config.coder_model,
+                 "planner_model": config.planner_model,
+                 "supervisor_model": config.supervisor_model,
+                 "critic_model": config.critic_model,
+                 "final_answer_model": config.final_answer_model,
+                 # Other settings from Configuration
+                 "max_search_results": config.max_search_results,
+                 "recursion_limit": config.recursion_limit,
+                 "max_iterations": config.max_iterations,
+                 "allow_agent_to_extract_answers": config.allow_agent_to_extract_answers
+             }
+         }
+
+         print(f"GaiaAgent initialized successfully with model: {self.config['configurable']['model']}")
+
+     def __call__(self, question: str) -> str:
+         """Process a question and return an answer formatted for the GAIA benchmark.
+
+         Args:
+             question: The GAIA benchmark question
+
+         Returns:
+             Answer formatted for GAIA benchmark evaluation
+         """
+         print(f"Agent received question: {question[:100]}...")
+
+         # Create a thread_id for this interaction
+         thread_id = str(uuid.uuid4())
+         self.config["configurable"]["thread_id"] = thread_id
+
+         # Import configuration
+         from react_agent.configuration import Configuration
+         config = Configuration.from_context()
+
+         # Add a system prompt to ensure proper GAIA format
+         system_prompt = """You are a general AI assistant. Answer the question concisely.
+ YOUR ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
+ If asked for a number, don't use commas or units like $ or % unless specified.
+ If asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits as plain text unless specified otherwise.
+ Focus on brevity and correctness."""
+
+         # Create the input state with the human message and system prompt
+         input_state = {
+             "messages": [HumanMessage(content=question)],
+             "configurable": {
+                 "thread_id": thread_id,
+                 "system_prompt": system_prompt,
+                 "model": config.model  # Ensure the model is also set in the state
+             }
+         }
+
+         # Process the question with our graph
+         try:
+             # Execute the graph and get the final state
+             # Use invoke instead of stream to limit operations
+             try:
+                 final_state = self.graph.invoke(input_state, config=self.config)
+             except Exception as e:
+                 # If we hit a recursion error, try again with a higher limit
+                 print(f"Initial invocation failed: {str(e)}")
+                 # Use double the recursion limit as a fallback
+                 self.config["configurable"]["recursion_limit"] = config.recursion_limit * 2
+                 final_state = self.graph.invoke(input_state, config=self.config)
+
+             # Extract the answer - either from gaia_answer or from the last message
+             if "gaia_answer" in final_state:
+                 answer = final_state["gaia_answer"]
+             else:
+                 messages = final_state.get("messages", [])
+                 answer = messages[-1].content if messages else "No answer generated."
+
+             # Clean the answer to ensure proper GAIA format (remove any FINAL ANSWER prefix)
+             if "FINAL ANSWER:" in answer:
+                 answer = answer.split("FINAL ANSWER:")[1].strip()
+
+             print(f"Agent returning answer: {answer[:100]}...")
+             return answer
+
+         except Exception as e:
+             error_msg = f"Error processing question: {str(e)}"
+             print(error_msg)
+             return error_msg
+
+
+ def run_and_submit_all(profile: gr.OAuthProfile | None):
+     """Fetch all questions, run the GaiaAgent on them, submit answers, and display the results."""
+
+     # --- Determine HF Space Runtime URL and Repo URL ---
+     space_id = os.getenv("SPACE_ID")  # Get the SPACE_ID for linking to the code
+
+     if profile:
+         username = f"{profile.username}"
+         print(f"User logged in: {username}")
+     else:
+         print("User not logged in.")
+         return "Please log in to Hugging Face with the button.", None
+
+     api_url = DEFAULT_API_URL
+     questions_url = f"{api_url}/questions"
+     submit_url = f"{api_url}/submit"
+
+     # 1. Instantiate Agent
+     try:
+         agent = GaiaAgent()
+     except Exception as e:
+         print(f"Error instantiating agent: {e}")
+         return f"Error initializing agent: {e}", None
+
+     # When the app runs as a Hugging Face Space, this link points to the codebase
+     agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
+     print(agent_code)
+
+     # 2. Fetch Questions
+     print(f"Fetching questions from: {questions_url}")
+     try:
+         response = requests.get(questions_url, timeout=15)
+         response.raise_for_status()
+         questions_data = response.json()
+         if not questions_data:
+             print("Fetched questions list is empty.")
+             return "Fetched questions list is empty or invalid format.", None
+         print(f"Fetched {len(questions_data)} questions.")
+     except requests.exceptions.RequestException as e:
+         print(f"Error fetching questions: {e}")
+         return f"Error fetching questions: {e}", None
+     except requests.exceptions.JSONDecodeError as e:
+         print(f"Error decoding JSON response from questions endpoint: {e}")
+         print(f"Response text: {response.text[:500]}")
+         return f"Error decoding server response for questions: {e}", None
+     except Exception as e:
+         print(f"An unexpected error occurred fetching questions: {e}")
+         return f"An unexpected error occurred fetching questions: {e}", None
+
+     # 3. Run the Agent
+     results_log = []
+     answers_payload = []
+     print(f"Running agent on {len(questions_data)} questions...")
+     for item in questions_data:
+         task_id = item.get("task_id")
+         question_text = item.get("question")
+         if not task_id or question_text is None:
+             print(f"Skipping item with missing task_id or question: {item}")
+             continue
+         try:
+             answer = agent(question_text)
+             # Format answers according to API requirements - use submitted_answer as required
+             answers_payload.append({
+                 "task_id": task_id,
+                 "submitted_answer": answer
+             })
+             results_log.append({
+                 "Task ID": task_id,
+                 "Question": question_text,
+                 "Answer": answer
+             })
+         except Exception as e:
+             print(f"Error running agent on task {task_id}: {e}")
+             results_log.append({
+                 "Task ID": task_id,
+                 "Question": question_text,
+                 "Answer": f"AGENT ERROR: {e}"
+             })
+
+     if not answers_payload:
+         print("Agent did not produce any answers to submit.")
+         return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
+
+     # 4. Prepare Submission
+     submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
+     status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
+     print(status_update)
+
+     # 5. Submit
+     print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
+     try:
+         response = requests.post(submit_url, json=submission_data, timeout=60)
+         response.raise_for_status()
+         result_data = response.json()
+         final_status = (
+             f"Submission Successful!\n"
+             f"User: {result_data.get('username')}\n"
+             f"Overall Score: {result_data.get('score', 'N/A')}% "
+             f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
+             f"Message: {result_data.get('message', 'No message received.')}"
+         )
+         print("Submission successful.")
+         results_df = pd.DataFrame(results_log)
+         return final_status, results_df
+     except requests.exceptions.HTTPError as e:
+         error_detail = f"Server responded with status {e.response.status_code}."
+         try:
+             error_json = e.response.json()
+             error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
+         except requests.exceptions.JSONDecodeError:
+             error_detail += f" Response: {e.response.text[:500]}"
+         status_message = f"Submission Failed: {error_detail}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except requests.exceptions.Timeout:
+         status_message = "Submission Failed: The request timed out."
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except requests.exceptions.RequestException as e:
+         status_message = f"Submission Failed: Network error - {e}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except Exception as e:
+         status_message = f"An unexpected error occurred during submission: {e}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+
+
+ # Function to test a single random question
+ def test_random_question():
+     """Fetch a random question from the API and run the agent on it."""
+     api_url = DEFAULT_API_URL
+     random_question_url = f"{api_url}/random-question"
+
+     try:
+         # Fetch a random question
+         response = requests.get(random_question_url, timeout=15)
+         response.raise_for_status()
+         question_data = response.json()
+
+         if not question_data:
+             return "Error: Received empty response from random question endpoint.", None
+
+         task_id = question_data.get("task_id")
+         question_text = question_data.get("question")
+
+         if not task_id or not question_text:
+             return "Error: Invalid question format received.", None
+
+         # Initialize the agent and get an answer
+         agent = GaiaAgent()
+         answer = agent(question_text)
+
+         # Return results
+         result = {
+             "Task ID": task_id,
+             "Question": question_text,
+             "Answer": answer
+         }
+
+         return "Test completed successfully.", result
+
+     except Exception as e:
+         return f"Error testing random question: {str(e)}", None
+
+
+ # --- Build Gradio Interface using Blocks ---
+ with gr.Blocks() as demo:
+     gr.Markdown("# GAIA Benchmark Agent Evaluation")
+     gr.Markdown(
+         """
+         **Instructions:**
+
+         1. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
+         2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the agent, submit answers, and see the score.
+         3. Alternatively, click 'Test on Random Question' to test the agent on a single random question.
+
+         ---
+         **Note:** Running the agent on all questions may take some time. Please be patient while the agent processes all the questions.
+         """
+     )
+
+     gr.LoginButton()
+
+     with gr.Tabs():
+         with gr.TabItem("Full Evaluation"):
+             run_button = gr.Button("Run Evaluation & Submit All Answers")
+             status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
+             results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
+
+             run_button.click(
+                 fn=run_and_submit_all,
+                 outputs=[status_output, results_table]
+             )
+
+         with gr.TabItem("Test Single Question"):
+             test_button = gr.Button("Test on Random Question")
+             test_status = gr.Textbox(label="Test Status", lines=2, interactive=False)
+             test_result = gr.JSON(label="Question and Answer")
+
+             test_button.click(
+                 fn=test_random_question,
+                 outputs=[test_status, test_result]
+             )
+
+
+ if __name__ == "__main__":
+     print("\n" + "-"*30 + " App Starting " + "-"*30)
+     # Check for SPACE_HOST and SPACE_ID at startup for information
+     space_host_startup = os.getenv("SPACE_HOST")
+     space_id_startup = os.getenv("SPACE_ID")  # Get SPACE_ID at startup
+
+     if space_host_startup:
+         print(f"✅ SPACE_HOST found: {space_host_startup}")
+         print(f"   Runtime URL should be: https://{space_host_startup}.hf.space")
+     else:
+         print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
+
+     if space_id_startup:  # Print repo URLs if SPACE_ID is found
+         print(f"✅ SPACE_ID found: {space_id_startup}")
+         print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
+         print(f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
+     else:
+         print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
+
+     print("-"*(60 + len(" App Starting ")) + "\n")
+
+     print("Launching Gradio Interface for GAIA Agent Evaluation...")
+     demo.launch(debug=True, share=False)
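
For a quick local check of the class above, a minimal sketch (assuming the `react_agent` package is importable and the relevant provider API keys are exported) might look like this; the question is a hypothetical stand-in:

```python
# Hypothetical smoke test; not part of the commit.
from app import GaiaAgent

agent = GaiaAgent()  # picks up the default models from Configuration
print(agent("What is the capital of France?"))  # expect a bare string like "Paris"
```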
configuration.py ADDED
@@ -0,0 +1,166 @@
+ """Define the configurable parameters for the agent supervisor system."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field, fields
+ from typing import Annotated
+
+ from langchain_core.runnables import ensure_config
+ from langgraph.config import get_config
+
+ from react_agent import prompts
+
+
+ @dataclass(kw_only=True)
+ class Configuration:
+     """The configuration for the agent supervisor system."""
+
+     # Supervisor configuration
+     supervisor_prompt: str = field(
+         default=prompts.SUPERVISOR_PROMPT,
+         metadata={
+             "description": "The system prompt for the supervisor agent. "
+             "This prompt guides how the supervisor delegates tasks to worker agents."
+         },
+     )
+
+     # Planner configuration
+     planner_prompt: str = field(
+         default=prompts.PLANNER_PROMPT,
+         metadata={
+             "description": "The system prompt for the planner agent. "
+             "This prompt guides how the planner creates structured plans."
+         },
+     )
+
+     # Critic configuration
+     critic_prompt: str = field(
+         default=prompts.CRITIC_PROMPT,
+         metadata={
+             "description": "The system prompt for the critic agent. "
+             "This prompt guides how the critic evaluates answers."
+         },
+     )
+
+     # Worker agents configuration
+     researcher_prompt: str = field(
+         default=prompts.RESEARCHER_PROMPT,
+         metadata={
+             "description": "The system prompt for the researcher agent. "
+             "This prompt defines the researcher's capabilities and limitations."
+         },
+     )
+
+     coder_prompt: str = field(
+         default=prompts.CODER_PROMPT,
+         metadata={
+             "description": "The system prompt for the coder agent. "
+             "This prompt defines the coder's capabilities and approach to programming tasks."
+         },
+     )
+
+     # Shared configuration
+     system_prompt: str = field(
+         default=prompts.SYSTEM_PROMPT,
+         metadata={
+             "description": "Legacy system prompt for backward compatibility. "
+             "This prompt is used when running the agent in non-supervisor mode."
+         },
+     )
+
+     # LLM configuration - default model for backward compatibility
+     model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+         default="openai/gpt-4o-mini",
+         metadata={
+             "description": "The default large language model used by the agents (provider/model_name)."
+         },
+     )
+
+     # Model for the researcher (information gathering)
+     researcher_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+         default="openai/gpt-4o-mini",
+         metadata={
+             "description": "The model used by the researcher agent for gathering information (provider/model_name)."
+         },
+     )
+
+     # Model for the coder (code execution) - uses Claude Sonnet
+     coder_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+         default="anthropic/claude-3-5-sonnet-20240620",
+         metadata={
+             "description": "The model used by the coder agent for programming tasks (provider/model_name)."
+         },
+     )
+
+     # Model for lightweight reasoning tasks (planner, supervisor)
+     planner_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+         default="google_genai/gemini-1.5-flash",
+         metadata={
+             "description": "The lightweight reasoning model used by the planner (provider/model_name)."
+         },
+     )
+
+     # The supervisor defaults to the same lightweight model as the planner
+     supervisor_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+         default="google_genai/gemini-1.5-flash",
+         metadata={
+             "description": "The model used by the supervisor for routing (provider/model_name)."
+         },
+     )
+
+     critic_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+         default="openai/gpt-4o-mini",
+         metadata={
+             "description": "The model used by the critic for evaluation (provider/model_name)."
+         },
+     )
+
+     # Model for final answer generation - uses Claude for precise formatting
+     final_answer_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field(
+         default="anthropic/claude-3-5-sonnet-20240620",
+         metadata={
+             "description": "The model used for generating the final answers in GAIA benchmark format (provider/model_name)."
+         },
+     )
+
+     # Tool configuration
+     max_search_results: int = field(
+         default=5,
+         metadata={
+             "description": "The maximum number of search results to return."
+         },
+     )
+
+     # Execution configuration
+     recursion_limit: int = field(
+         default=50,
+         metadata={
+             "description": "Maximum number of recursion steps allowed in the LangGraph execution."
+         },
+     )
+
+     max_iterations: int = field(
+         default=12,
+         metadata={
+             "description": "Maximum number of iterations allowed to prevent infinite loops."
+         },
+     )
+
+     allow_agent_to_extract_answers: bool = field(
+         default=True,
+         metadata={
+             "description": "Whether to allow the agent to extract answers from context when formatting fails."
+         },
+     )
+
+     @classmethod
+     def from_context(cls) -> Configuration:
+         """Create a Configuration instance from the ambient RunnableConfig, if any."""
+         try:
+             config = get_config()
+         except RuntimeError:
+             config = None
+         config = ensure_config(config)
+         configurable = config.get("configurable") or {}
+         _fields = {f.name for f in fields(cls) if f.init}
+         return cls(**{k: v for k, v in configurable.items() if k in _fields})
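
Because `from_context()` keeps only the keys that match dataclass fields, any setting above can be overridden per invocation through the standard `configurable` mapping. A sketch under that assumption (the question and override values are illustrative only):

```python
# Hypothetical per-run override; field names come from the Configuration dataclass above.
from langchain_core.messages import HumanMessage
from react_agent.graph import get_compiled_graph

graph = get_compiled_graph()
result = graph.invoke(
    {"messages": [HumanMessage(content="How many moons does Mars have?")]},
    config={"configurable": {
        "model": "openai/gpt-4o",   # swap the default LLM
        "max_search_results": 3,    # trim Tavily results
        "recursion_limit": 30,
    }},
)
```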
graph.py ADDED
@@ -0,0 +1,574 @@
+ """Define an Agent Supervisor graph with specialized worker agents.
+
+ The supervisor routes tasks to specialized agents based on the query type.
+ """
+
+ from typing import Dict, List, Literal, Optional, Union, Type, cast
+
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
+ from langchain_core.prompts import ChatPromptTemplate
+ from langgraph.graph import StateGraph, START, END
+ # Import adjusted for compatibility
+ from langgraph.prebuilt import create_react_agent  # Try the original import path first
+ from langgraph.types import Command
+
+ from react_agent.configuration import Configuration
+ from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router, Plan, PlanStep, CriticVerdict
+ from react_agent.tools import TOOLS, tavily_tool, python_repl_tool
+ from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
+ from react_agent import prompts
+ from react_agent.supervisor_node import supervisor_node
+
+
+ # Compile-time type definitions
+ SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]
+ WorkerDestination = Literal["supervisor"]
+
+
+ # Helper function to check if a message is from a user
+ def is_user_message(message):
+     """Check if a message is from a user regardless of message format."""
+     if isinstance(message, dict):
+         return message.get("role") == "user"
+     elif isinstance(message, HumanMessage):
+         return True
+     return False
+
+
+ # Helper function to get message content
+ def get_message_content(message):
+     """Extract content from a message regardless of format."""
+     if isinstance(message, dict):
+         return message.get("content", "")
+     elif hasattr(message, "content"):
+         return message.content
+     return ""
+
+
+ # --- Planner node ---------------------------------------------------------
+
+ def planner_node(state: State) -> Command[WorkerDestination]:
+     """Planning LLM that creates a step-by-step execution plan.
+
+     Args:
+         state: The current state with messages
+
+     Returns:
+         Command to update the state with a plan
+     """
+     configuration = Configuration.from_context()
+     # Use the specific planner model
+     planner_llm = load_chat_model(configuration.planner_model)
+
+     # Track steps
+     steps_taken = state.get("steps_taken", 0)
+     steps_taken += 1
+
+     # Get the original user question (the latest user message)
+     user_messages = [m for m in state["messages"] if is_user_message(m)]
+     original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"
+
+     # Create a chat prompt template with proper formatting
+     planner_prompt_template = ChatPromptTemplate.from_messages([
+         ("system", prompts.PLANNER_PROMPT),
+         ("user", "{question}")
+     ])
+
+     # Format the prompt with the necessary variables
+     formatted_messages = planner_prompt_template.format_messages(
+         question=original_question,
+         system_time=format_system_prompt("{system_time}"),
+         workers=", ".join(WORKERS),
+         worker_options=", ".join([f'"{w}"' for w in WORKERS]),
+         example_worker_1=WORKERS[0] if WORKERS else "researcher",
+         example_worker_2=WORKERS[1] if len(WORKERS) > 1 else "coder"
+     )
+
+     # Get structured output from the planner model
+     plan = planner_llm.with_structured_output(Plan).invoke(formatted_messages)
+
+     # Return with the updated state
+     return Command(
+         goto="supervisor",
+         update={
+             "plan": plan,
+             "current_step_index": 0,
+             # Add a message to show the plan was created
+             "messages": [
+                 HumanMessage(
+                     content=f"Created plan with {len(plan['steps'])} steps",
+                     name="planner"
+                 )
+             ],
+             "steps_taken": steps_taken
+         }
+     )
+
+
+ # --- Final Answer node -----------------------------------------------------
+
+ def final_answer_node(state: State) -> Command[Literal["__end__"]]:
+     """Generate a final answer based on the gathered information.
+
+     Args:
+         state: The current state with messages and context
+
+     Returns:
+         Command with the final answer
+     """
+     configuration = Configuration.from_context()
+
+     # Track steps
+     steps_taken = state.get("steps_taken", 0)
+     steps_taken += 1
+
+     # Check if we've exhausted retries and already have a draft answer
+     retry_exhausted = state.get("retry_exhausted", False)
+     draft_answer = state.get("draft_answer")
+
+     # Variable to store the final answer
+     gaia_answer = ""
+
+     if retry_exhausted and draft_answer and draft_answer.startswith("FINAL ANSWER:"):
+         # If the supervisor already provided a properly formatted answer after exhausting retries,
+         # use it directly without calling the model again
+         import re
+         final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", draft_answer, re.IGNORECASE)
+         if final_answer_match:
+             gaia_answer = final_answer_match.group(1).strip()
+         else:
+             gaia_answer = "unknown"
+     else:
+         # Use the specific final answer model
+         final_llm = load_chat_model(configuration.final_answer_model)
+
+         # Get the original user question (the latest user message)
+         user_messages = [m for m in state["messages"] if is_user_message(m)]
+         original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"
+
+         # Check if we already have a draft answer from the supervisor
+         if draft_answer and draft_answer.startswith("FINAL ANSWER:"):
+             # If the supervisor already provided a properly formatted answer, use it directly
+             raw_answer = draft_answer
+         else:
+             # Get the context and worker results
+             context = state.get("context", {})
+             worker_results = state.get("worker_results", {})
+
+             # Compose a prompt for the final answer using the GAIA-specific format
+             final_prompt = ChatPromptTemplate.from_messages([
+                 ("system", prompts.FINAL_ANSWER_PROMPT),
+                 ("user", prompts.FINAL_ANSWER_USER_PROMPT)
+             ])
+
+             # Format the context information more effectively
+             context_list = []
+             # First include researcher context, as it provides background
+             if "researcher" in context:
+                 context_list.append(f"Research information: {context['researcher']}")
+
+             # Then include coder results, which are typically calculations
+             if "coder" in context:
+                 context_list.append(f"Calculation results: {context['coder']}")
+
+             # Add any other workers
+             for worker, content in context.items():
+                 if worker not in ["researcher", "coder"]:
+                     context_list.append(f"{worker.capitalize()}: {content}")
+
+             # Get the final answer
+             formatted_messages = final_prompt.format_messages(
+                 question=original_question,
+                 context="\n\n".join(context_list)
+             )
+
+             raw_answer = final_llm.invoke(formatted_messages).content
+
+         # Extract the answer in GAIA format: "FINAL ANSWER: [x]"
+         import re
+         gaia_answer = raw_answer
+         final_answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", raw_answer, re.IGNORECASE)
+         if final_answer_match:
+             gaia_answer = final_answer_match.group(1).strip()
+
+     # Ensure the answer is properly formatted - if we don't have a valid answer
+     # but have sufficient context, try to extract directly
+     if configuration.allow_agent_to_extract_answers and (not gaia_answer or gaia_answer.lower() in ["unknown", "insufficient information"]):
+         context = state.get("context", {})
+         from react_agent.supervisor_node import extract_best_answer_from_context
+         extracted_answer = extract_best_answer_from_context(context)
+         if extracted_answer != "unknown":
+             gaia_answer = extracted_answer
+
+     # Set status to "final_answer_generated" to indicate we're done
+     return Command(
+         goto=END,
+         update={
+             "messages": [
+                 AIMessage(
+                     content=f"FINAL ANSWER: {gaia_answer}",
+                     name="supervisor"
+                 )
+             ],
+             "next": "FINISH",  # Update next to indicate we're done
+             "gaia_answer": gaia_answer,  # Store the answer in GAIA-compatible format
+             "submitted_answer": gaia_answer,  # Store as submitted_answer for the GAIA benchmark
+             "status": "final_answer_generated",  # Add status to indicate we're complete
+             "steps_taken": steps_taken
+         }
+     )
+
+
+ # --- Critic node ----------------------------------------------------------
+
+ def critic_node(state: State) -> Command[Union[WorkerDestination, SupervisorDestinations]]:
+     """Critic that evaluates whether the answer fully satisfies the request.
+
+     Args:
+         state: The current state with messages and draft answer
+
+     Returns:
+         Command with the evaluation verdict
+     """
+     configuration = Configuration.from_context()
+     # Use the specific critic model
+     critic_llm = load_chat_model(configuration.critic_model)
+
+     # Track steps
+     steps_taken = state.get("steps_taken", 0)
+     steps_taken += 1
+
+     # Get the original user question (the latest user message)
+     user_messages = [m for m in state["messages"] if is_user_message(m)]
+     original_question = get_message_content(user_messages[-1]) if user_messages else "Help me"
+
+     # Get the draft answer
+     draft_answer = state.get("draft_answer", "No answer provided.")
+
+     # Create a chat prompt template with proper formatting
+     critic_prompt_template = ChatPromptTemplate.from_messages([
+         ("system", prompts.CRITIC_PROMPT),
+         ("user", prompts.CRITIC_USER_PROMPT)
+     ])
+
+     # Format the prompt with the necessary variables
+     formatted_messages = critic_prompt_template.format_messages(
+         question=original_question,
+         answer=draft_answer,
+         system_time=format_system_prompt("{system_time}"),
+         correct_verdict=VERDICTS[0] if VERDICTS else "CORRECT",
+         retry_verdict=VERDICTS[1] if len(VERDICTS) > 1 else "RETRY"
+     )
+
+     # Get structured output from the critic model
+     verdict = critic_llm.with_structured_output(CriticVerdict).invoke(formatted_messages)
+
+     # Add a message about the verdict
+     if verdict["verdict"] == VERDICTS[0]:  # CORRECT
+         verdict_message = "Answer is complete, accurate, and properly formatted for GAIA."
+         goto = "final_answer"  # Go to the final answer node if correct
+     else:
+         verdict_message = f"Answer needs improvement. Reason: {verdict.get('reason', 'Unknown')}"
+         goto = "supervisor"
+
+     # Return with the updated state
+     return Command(
+         goto=goto,
+         update={
+             "critic_verdict": verdict,
+             "messages": [
+                 HumanMessage(
+                     content=verdict_message,
+                     name="critic"
+                 )
+             ],
+             "steps_taken": steps_taken
+         }
+     )
+
+
+ # --- Worker agent factory -------------------------------------------------
+
+ def create_worker_node(worker_type: str):
+     """Factory function to create a worker node of the specified type.
+
+     Args:
+         worker_type: The type of worker to create (must be in WORKERS)
+
+     Returns:
+         A function that processes requests for the specified worker type
+     """
+     if worker_type not in WORKERS:
+         raise ValueError(f"Unknown worker type: {worker_type}")
+
+     configuration = Configuration.from_context()
+
+     # Select the appropriate model for each worker type
+     if worker_type == "researcher":
+         llm = load_chat_model(configuration.researcher_model)
+         worker_prompt = prompts.RESEARCHER_PROMPT
+         worker_tools = [tavily_tool]
+     elif worker_type == "coder":
+         llm = load_chat_model(configuration.coder_model)
+         worker_prompt = prompts.CODER_PROMPT
+         worker_tools = [python_repl_tool]
+     else:
+         # Default case
+         llm = load_chat_model(configuration.model)
+         worker_prompt = getattr(prompts, f"{worker_type.upper()}_PROMPT", prompts.SYSTEM_PROMPT)
+         worker_tools = TOOLS
+
+     # Create the agent
+     worker_agent = create_react_agent(
+         llm,
+         tools=worker_tools,
+         prompt=format_system_prompt(worker_prompt)
+     )
+
+     # Define the node function
+     def worker_node(state: State) -> Command[WorkerDestination]:
+         """Process requests using the specified worker.
+
+         Args:
+             state: The current conversation state
+
+         Returns:
+             Command to return to the supervisor with results
+         """
+         # Track steps
+         steps_taken = state.get("steps_taken", 0)
+         steps_taken += 1
+
+         # Get the last message from the supervisor, which contains our task
+         task_message = None
+         if state.get("messages"):
+             for msg in reversed(state["messages"]):
+                 if hasattr(msg, "name") and msg.name == "supervisor":
+                     task_message = msg
+                     break
+
+         if not task_message:
+             return Command(
+                 goto="supervisor",
+                 update={
+                     "messages": [
+                         HumanMessage(
+                             content=f"Error: No task message found for {worker_type}",
+                             name=worker_type
+                         )
+                     ],
+                     "steps_taken": steps_taken
+                 }
+             )
+
+         # Create a new state with just the relevant messages for this worker
+         # This prevents confusion from unrelated parts of the conversation
+         agent_input = {
+             "messages": [
+                 # Include the first user message for context
+                 state["messages"][0] if state["messages"] else HumanMessage(content="Help me"),
+                 # Include the task message
+                 task_message
+             ]
+         }
+
+         # Invoke the agent with the clean input
+         result = worker_agent.invoke(agent_input)
+
+         # Extract the result from the agent response
+         result_content = extract_worker_result(worker_type, result, state)
+
+         # Store the worker's result in the shared context
+         context_update = state.get("context", {}).copy()
+         context_update[worker_type] = result_content
+
+         # Store in the worker_results history
+         worker_results = state.get("worker_results", {}).copy()
+         if worker_type not in worker_results:
+             worker_results[worker_type] = []
+         worker_results[worker_type].append(result_content)
+
+         # Increment the step index after the worker completes
+         current_step_index = state.get("current_step_index", 0)
+
+         return Command(
+             update={
+                 "messages": [
+                     HumanMessage(content=result_content, name=worker_type)
+                 ],
+                 "current_step_index": current_step_index + 1,
+                 "context": context_update,
+                 "worker_results": worker_results,
+                 "steps_taken": steps_taken
+             },
+             goto="supervisor",
+         )
+
+     return worker_node
+
+
+ def extract_worker_result(worker_type: str, result: dict, state: State) -> str:
+     """Extract a clean, useful result from the worker's output.
+
+     This handles different response formats from different worker types.
+
+     Args:
+         worker_type: The type of worker (researcher or coder)
+         result: The raw result from the worker agent
+         state: The current state for context
+
+     Returns:
+         A cleaned string with the relevant result information
+     """
+     # Handle empty results
+     if not result or "messages" not in result or not result["messages"]:
+         return f"No output from {worker_type}"
+
+     # Get the last message from the agent
+     last_message = result["messages"][-1]
+
+     # Default to extracting content directly
+     if hasattr(last_message, "content") and last_message.content:
+         result_content = last_message.content
+     else:
+         result_content = f"No content from {worker_type}"
+
+     # Special handling based on worker type
+     if worker_type == "coder":
+         # For coder outputs, extract the actual result values from code execution
+         if "```" in result_content:
+             # Try to extract stdout from code execution
+             import re
+             stdout_match = re.search(r"Stdout:\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
+             if stdout_match:
+                 # Extract the actual execution output, not just the code
+                 execution_result = stdout_match.group(1).strip()
+                 if execution_result:
+                     # Check if this is just a simple number result
+                     if re.match(r"^\d+(\.\d+)?$", execution_result):
+                         return execution_result
+                     else:
+                         return f"Code executed with result: {execution_result}"
+
+             # If we couldn't find stdout, try to extract output in a different way
+             # Look for "Result:" or similar indicators
+             result_match = re.search(r"(?:Result|Output|Answer):\s*(.*?)(?:\n\n|$)", result_content, re.DOTALL)
+             if result_match:
+                 return result_match.group(1).strip()
+
+     elif worker_type == "researcher":
+         # For researcher outputs, keep the full detailed response
+         # but ensure it's well-formatted
+         if len(result_content) > 800:
+             # If too long, try to extract key sections
+             # Look for summary or conclusion sections
+             import re
+             summary_match = re.search(r"(?:Summary|Conclusion|To summarize|In summary):(.*?)(?:\n\n|$)",
+                                       result_content, re.IGNORECASE | re.DOTALL)
+             if summary_match:
+                 return summary_match.group(1).strip()
+
+     # If no special handling was triggered, return the content as is
+     return result_content
+
+
+ # --- Graph assembly -------------------------------------------------------
+
+ def create_agent_supervisor_graph() -> StateGraph:
+     """Create the agent supervisor graph with all nodes and edges.
+
+     Returns:
+         The StateGraph builder, ready to be compiled with an optional checkpointer
+     """
+     # Initialize the graph with our State type
+     builder = StateGraph(State)
+
+     # Add control nodes
+     builder.add_node("planner", planner_node)
+     builder.add_node("supervisor", supervisor_node)
+     builder.add_node("critic", critic_node)
+     builder.add_node("final_answer", final_answer_node)
+
+     # Add worker nodes dynamically based on the WORKERS list
+     for worker_type in WORKERS:
+         builder.add_node(worker_type, create_worker_node(worker_type))
+
+     # Define the workflow
+     builder.add_edge(START, "supervisor")
+     builder.add_edge("planner", "supervisor")
+     builder.add_edge("critic", "supervisor")
+     builder.add_edge("critic", "final_answer")  # Add edge from critic to final_answer
+     builder.add_edge("final_answer", END)  # The final answer node goes to END
+     builder.add_edge("supervisor", END)  # Allow the supervisor to end the workflow
+
+     # Connect all workers to the supervisor
+     for worker_type in WORKERS:
+         builder.add_edge(worker_type, "supervisor")
+
+     # Return the builder, not a compiled graph
+     # This allows the caller to compile with a checkpointer
+     return builder
+
+
+ # --- Graph instantiation (with flexible checkpointing) -----------------------------
+
+ def get_compiled_graph(checkpointer=None):
+     """Get a compiled graph with an optional checkpointer.
+
+     Args:
+         checkpointer: Optional checkpointer for persistence
+
+     Returns:
+         Compiled StateGraph ready for execution
+     """
+     # Get configuration
+     configuration = Configuration.from_context()
+
+     builder = create_agent_supervisor_graph()
+
+     # Define a termination condition to prevent loops
+     def should_end(state):
+         """Determine if the graph should terminate."""
+         # End if status is set to final_answer_generated
+         if state.get("status") == "final_answer_generated":
+             return True
+
+         # End if the retry_exhausted flag is set and we've gone through final_answer
+         if state.get("retry_exhausted") and state.get("gaia_answer"):
+             return True
+
+         # End if we've hit the maximum recursion limit defined by LangGraph
+         steps_taken = state.get("steps_taken", 0)
+         if steps_taken >= configuration.recursion_limit - 5:  # Leave a buffer
+             return True
+
+         return False
+
+     # Define a step counter for tracking the step count
+     def count_steps(state):
+         """Count steps to prevent infinite loops."""
+         steps_taken = state.get("steps_taken", 0)
+         return {"steps_taken": steps_taken + 1}
+
+     # Compile the graph (don't use add_state_transform, which isn't available)
+     if checkpointer:
+         graph = builder.compile(
+             checkpointer=checkpointer,
+             name="Structured Reasoning Loop"
+         )
+     else:
+         graph = builder.compile(
+             name="Structured Reasoning Loop"
+         )
+
+     # Configure the graph with the recursion limit and max iterations
+     graph = graph.with_config({
+         "recursion_limit": configuration.recursion_limit,
+         "max_iterations": configuration.max_iterations
+     })
+
+     return graph
+
+
+ # Initialize a default non-checkpointed graph (for backward compatibility)
+ graph = get_compiled_graph()
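
Since `create_agent_supervisor_graph()` returns the builder rather than a compiled graph, callers choose the checkpointer at compile time. A sketch of resuming a conversation thread with the in-memory saver (thread id and question are arbitrary):

```python
# Hypothetical usage; mirrors what GaiaAgent does in app.py.
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from react_agent.graph import get_compiled_graph

graph = get_compiled_graph(checkpointer=MemorySaver())
config = {"configurable": {"thread_id": "demo-thread"}}

# Invoking again with the same thread_id resumes from the checkpointed state.
state = graph.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]}, config=config)
print(state.get("gaia_answer"))
```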
prompts.py ADDED
@@ -0,0 +1,180 @@
+ """System prompts used by the agent supervisor and worker agents."""
+
+ from react_agent.state import WORKERS, VERDICTS
+
+ # --- Supervisor prompt -----------------------------------------------------
+
+ SUPERVISOR_PROMPT = """You are a supervisor tasked with managing a conversation between the \
+ following workers: {workers}. Given the following user request, \
+ respond with the worker to act next. Each worker will perform a \
+ task and respond with their results and status. When finished, \
+ respond with FINISH.
+
+ System time: {system_time}"""
+
+ # --- Planner prompt -------------------------------------------------------
+
+ PLANNER_PROMPT = """**Role**: You are a Planner node in a LangGraph supervisor workflow.
+ **Goal**: Given the user's original request, create a concise, focused plan that directly answers the question.
+
+ Requirements:
+ 1. Output only a JSON object with one key `steps`, whose value is an **ordered list** of at least 1 and at most 3 objects.
+    Each object has:
+    • `worker` – one of: {worker_options}
+    • `instruction` – ≤ 20 words telling that worker what to do
+
+ 2. Your plan MUST:
+    • Directly address the user's specific question
+    • Include at least one step (never return empty steps)
+    • Be focused on finding the exact answer requested, not the process of answering
+    • Use researcher for information gathering
+    • Use coder for calculations or data analysis if needed
+
+ 3. Common tasks:
+    • For factual questions: use researcher to find the specific fact
+    • For calculations: use researcher to find data, then coder to calculate
+    • For multiple-part questions: break into steps with the right workers
+    • Ensure your last step gets the exact answer in the format requested
+
+ Example:
+ ```
+ {{
+   "steps": [
+     {{"worker": "{example_worker_1}", "instruction": "Find inflation rate in 2023"}},
+     {{"worker": "{example_worker_2}", "instruction": "Compute average of 2019–2023 rates"}}
+   ]
+ }}
+ ```
+
+ System time: {system_time}"""
+
+ # --- Critic prompt --------------------------------------------------------
+
+ CRITIC_PROMPT = """**Role**: You are a Critic node specializing in GAIA benchmark format validation.
+ **Goal**: Strictly check whether the answer follows GAIA format requirements.
+
+ Requirements:
+ 1. You will check if the answer:
+    • Addresses all parts of the user's question correctly
+    • Follows the EXACT required GAIA format: "FINAL ANSWER: [concise response]"
+    • Contains ONLY the essential information in the [concise response]:
+      - A single number (no commas, no units like $ or % unless specified)
+      - A single word or very short phrase
+      - A comma-separated list of numbers or strings
+    • Has NO explanations, reasoning, or extra text
+    • For strings: no articles or abbreviations
+    • For numbers: digits only, without commas
+
+ 2. If the answer is CORRECT, respond ONLY with this exact JSON:
+    • `{{"verdict":"{correct_verdict}"}}`
+
+ 3. If ANY requirement is NOT MET, respond with this JSON including a SPECIFIC reason:
+    • `{{"verdict":"{retry_verdict}","reason":"<specific format issue>"}}`
+    • IMPORTANT: You MUST provide a substantive reason that clearly explains what's wrong
+    • NEVER leave the reason empty or only containing quotes
+
+ 4. Common reason examples:
+    • "Answer not formatted as 'FINAL ANSWER: [response]'"
+    • "Answer contains explanations instead of just the concise response"
+    • "Answer does not address the question about [specific topic]"
+    • "Answer contains units when it should just be a number"
+
+ DO NOT include any text before or after the JSON. Your complete response must be valid JSON that can be parsed.
+
+ System time: {system_time}"""
+
+ # --- Critic user prompt ---------------------------------------------------
+
+ CRITIC_USER_PROMPT = """Original question: {question}
+
+ Draft answer: {answer}
+
+ Check if the draft answer follows GAIA format requirements:
+ 1. Format must be exactly "FINAL ANSWER: [concise response]"
+ 2. [concise response] must ONLY be:
+    - A single number (no commas or units unless specified)
+    - A single word or very short phrase
+    - A comma-separated list of numbers or strings
+ 3. NO explanations or additional text is allowed
+ 4. Strings should not have articles or abbreviations
+ 5. Numbers should be in digits without commas
+
+ Does the answer meet these requirements and correctly answer the question?"""
+
+ # --- Final Answer format for GAIA benchmark -------------------------------
+
+ FINAL_ANSWER_PROMPT = """You are a response formatter for a GAIA benchmark question.
+
+ Your only job is to format the final answer in the exact format required: "FINAL ANSWER: [concise response]"
+
+ Requirements for [concise response]:
+ 1. Response must ONLY be one of these formats:
+    - A single number (no commas, no units like $ or % unless specified)
+    - A single word or very short phrase
+    - A comma-separated list of numbers or strings
+ 2. DO NOT include any explanations, reasoning, or extra text
+ 3. For strings, don't use articles or abbreviations unless specified
+ 4. For numbers, write digits (not spelled out) without commas
+ 5. The response should be as concise as possible while being correct
+
+ Original question: {question}
+
+ Information available:
+ {context}
+
+ After reviewing the information, extract just the essential answer and output ONLY:
+ FINAL ANSWER: [your concise response]
+ """
+
+ # --- Final Answer user prompt ---------------------------------------------
+
+ FINAL_ANSWER_USER_PROMPT = """Original question: {question}
+
+ Information available:
+ {context}
+
+ Remember to output ONLY 'FINAL ANSWER: [your concise response]' with no explanations."""
+
+ # --- Worker agent prompts -------------------------------------------------
+
+ RESEARCHER_PROMPT = """You are a research specialist focused on finding information and providing context.
+
+ Your key responsibilities:
+ 1. Search for accurate, up-to-date information on any topic
+ 2. Provide factual knowledge about products, concepts, and terminology
+ 3. Explain real-world contexts and background information
+ 4. Identify relevant parameters and variables needed for calculations
+ 5. Present information clearly with proper citations
+
+ DO NOT perform complex calculations or coding tasks - these will be handled by the coder agent.
+ You MAY provide simple arithmetic or basic formulas to illustrate concepts.
+
+ Always return information in a structured, organized format that will be useful for the next steps.
+
+ System time: {system_time}
+ """
+
+ CODER_PROMPT = """You are a computational specialist focused on calculations, coding, and data analysis.
+
+ Your key responsibilities:
+ 1. Write and execute Python code for calculations and data manipulation
+ 2. Perform precise numerical analyses based on inputs from the researcher
+ 3. Format results clearly with appropriate units and precision
+ 4. Use markdown to structure your response with headings and bullet points
+ 5. Verify calculations through multiple methods when possible
+
+ Important:
+ 1. Always include both your calculation process AND final result values
+ 2. Always clearly state your assumptions when making calculations
+ 3. Format numerical results with appropriate precision and units
+ 4. When receiving data from the researcher, acknowledge and build upon it directly
+ 5. If a calculation involves multiple steps or cases, organize them with headings
+
+ System time: {system_time}
+ """
+
+ # --- Legacy system prompt (kept for backward compatibility) ---------------
+
+ SYSTEM_PROMPT = """You are a helpful AI assistant.
+
+ System time: {system_time}"""
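
These templates are plain `str.format`-style strings, which is why the literal braces in the JSON examples are doubled (`{{ }}`). A sketch of how a node fills one in, matching what `planner_node` does in graph.py (the question and timestamp are placeholders):

```python
from langchain_core.prompts import ChatPromptTemplate
from react_agent import prompts
from react_agent.state import WORKERS

template = ChatPromptTemplate.from_messages([
    ("system", prompts.PLANNER_PROMPT),
    ("user", "{question}"),
])
messages = template.format_messages(
    question="What was the average inflation rate over 2019-2023?",
    system_time="2025-01-01T00:00:00",  # placeholder; graph.py derives this via format_system_prompt
    worker_options=", ".join(f'"{w}"' for w in WORKERS),
    example_worker_1=WORKERS[0],
    example_worker_2=WORKERS[1],
)
```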
state.py ADDED
@@ -0,0 +1,72 @@
+ """Define the state structures for the agent supervisor."""
+
+ from __future__ import annotations
+
+ from typing import Dict, List, Literal, Optional, Sequence, Any
+
+ from langchain_core.messages import AnyMessage
+ from langgraph.graph import MessagesState, add_messages
+ from typing_extensions import TypedDict, Annotated
+
+
+ # --- Constants and shared definitions ---------------------------------------
+
+ # Define worker types (specialized agents that perform tasks)
+ WORKERS = ["researcher", "coder"]
+
+ # Define all member types (including control nodes)
+ MEMBERS = WORKERS + ["planner", "critic", "supervisor"]
+
+ # Define status/routing options
+ VERDICTS = ["CORRECT", "RETRY"]
+ ROUTING = ["FINISH"] + WORKERS
+ OPTIONS = ROUTING + VERDICTS
+
+
+ # --- Router for supervisor decisions ---------------------------------------
+
+ class Router(TypedDict):
+     """Determines which worker to route to next, or if the task is complete.
+
+     The supervisor returns this structure to navigate the workflow.
+     Valid values are defined in the ROUTING list.
+     """
+     next: Literal[*ROUTING]
+
+
+ # --- Plan structure for the Planner node -----------------------------------
+
+ class PlanStep(TypedDict):
+     """A single step in the plan created by the Planner."""
+     worker: Literal[*WORKERS]
+     instruction: str
+
+
+ class Plan(TypedDict):
+     """The complete plan produced by the Planner node."""
+     steps: List[PlanStep]
+
+
+ # --- Critic verdict structure ----------------------------------------------
+
+ class CriticVerdict(TypedDict):
+     """The verdict from the Critic on whether the answer is satisfactory."""
+     verdict: Literal[*VERDICTS]
+     reason: Optional[str]
+
+
+ # --- State for the agent supervisor ----------------------------------------
+
+ class State(MessagesState):
+     """State for the agent supervisor workflow.
+
+     Extends MessagesState, which provides message history tracking.
+     Adds fields to track routing information, the plan, and the critic verdict.
+     """
+     next: str
+     plan: Optional[Plan] = None
+     current_step_index: Optional[int] = None
+     draft_answer: Optional[str] = None
+     critic_verdict: Optional[CriticVerdict] = None
+     context: Dict[str, Any] = {}  # Shared context accessible to all agents
+     worker_results: Dict[str, List[str]] = {}  # Store results from each worker
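
Since these are `TypedDict`s, the structured-output calls in graph.py return plain dicts; hand-built literals show the shapes the nodes exchange (the step contents are illustrative):

```python
from react_agent.state import CriticVerdict, Plan

plan: Plan = {
    "steps": [
        {"worker": "researcher", "instruction": "Find inflation rate in 2023"},
        {"worker": "coder", "instruction": "Compute average of 2019-2023 rates"},
    ]
}
verdict: CriticVerdict = {"verdict": "RETRY", "reason": "Answer contains explanations"}
```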
supervisor_node.py ADDED
@@ -0,0 +1,406 @@
+ """Supervisor node implementation for the agent supervisor system."""
+
+ from typing import Dict, List, Literal, Optional, Union, Type, cast
+
+ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
+ from langchain_core.prompts import ChatPromptTemplate
+ from langgraph.graph import StateGraph, START, END
+ from langgraph.types import Command
+
+ from react_agent.configuration import Configuration
+ from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router
+ from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
+ from react_agent import prompts
+
+
+ # Compile-time type definitions
+ SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]
+
+
+ def supervisor_node(state: State) -> Command[SupervisorDestinations]:
+     """Supervising LLM that decides which specialized agent should act next.
+
+     Args:
+         state: The current state with messages
+
+     Returns:
+         Command with routing information
+     """
+     # Get configuration to use supervisor_model
+     configuration = Configuration.from_context()
+
+     # Track steps to prevent infinite loops
+     steps_taken = state.get("steps_taken", 0)
+     steps_taken += 1
+     state_updates = {"steps_taken": steps_taken}
+
+     # Check if we've hit our step limit
+     if steps_taken >= configuration.recursion_limit - 5:  # Buffer of 5 steps
+         # Extract the best answer we have from context, if possible
+         context = state.get("context", {})
+         answer = extract_best_answer_from_context(context)
+
+         return Command(
+             goto="final_answer",
+             update={
+                 "messages": [
+                     HumanMessage(
+                         content=f"Maximum steps ({steps_taken}) reached. Extracting best answer from available information.",
+                         name="supervisor"
+                     )
+                 ],
+                 "draft_answer": f"FINAL ANSWER: {answer}",
+                 "retry_exhausted": True,  # Flag to indicate we've exhausted retries
+                 "steps_taken": steps_taken
+             }
+         )
+
+     # Safety check - prevent infinite loops by forcing termination after too many retry steps
+     retry_count = state.get("retry_count", 0)
+     max_retries = 2  # Maximum number of allowed retries
+
+     if retry_count > max_retries:
+         # Extract the best answer we have from context, if possible
+         context = state.get("context", {})
+         answer = extract_best_answer_from_context(context)
+
+         return Command(
+             goto="final_answer",
+             update={
+                 "messages": [
+                     HumanMessage(
+                         content=f"Maximum retries ({max_retries}) reached. Extracting best answer from available information.",
+                         name="supervisor"
+                     )
+                 ],
+                 "draft_answer": f"FINAL ANSWER: {answer}",
+                 "retry_exhausted": True,  # Flag to indicate we've exhausted retries
+                 "steps_taken": steps_taken
+             }
+         )
+
+     # Check if we need a plan
+     if not state.get("plan"):
+         return Command(
+             goto="planner",
+             update={
+                 **state_updates
+             }
+         )
+
+     # Validate that the plan has at least one step
+     plan = state.get("plan")
+     if not plan.get("steps") or len(plan.get("steps", [])) == 0:
+         # The plan has no steps; go back to the planner with explicit instructions
+         return Command(
+             goto="planner",
+             update={
+                 "messages": [
+                     HumanMessage(
+                         content="Previous plan had 0 steps. Please create a plan with at least 1 step to solve the user's question.",
+                         name="supervisor"
+                     )
+                 ],
+                 "plan": None,
+                 **state_updates
+             }
+         )
+
+     # Check if we have a critic verdict that requires replanning
+     critic_verdict = state.get("critic_verdict")
+     if critic_verdict:
+         if critic_verdict.get("verdict") == VERDICTS[0]:  # CORRECT
+             # The final answer is approved; navigate to the final_answer node
+             # This will generate a polished response before ending
+             return Command(
+                 goto="final_answer",
+                 update={
+                     "messages": [
119
+ HumanMessage(
120
+ content="Answer approved by critic. Generating final response.",
121
+ name="supervisor"
122
+ )
123
+ ]
124
+ }
125
+ )
126
+ elif critic_verdict.get("verdict") == VERDICTS[1]: # RETRY
127
+ # IMPORTANT: Get the current retry count BEFORE incrementing
128
+ current_retry_count = state.get("retry_count", 0)
129
+
130
+ # Check if we're at the maximum allowed retries
131
+ if current_retry_count >= max_retries:
132
+ # Extract best answer and go to final_answer
133
+ context = state.get("context", {})
134
+ answer = extract_best_answer_from_context(context)
135
+
136
+ return Command(
137
+ goto="final_answer",
138
+ update={
139
+ "messages": [
140
+ HumanMessage(
141
+ content=f"Maximum retries ({max_retries}) reached. Proceeding with best available answer.",
142
+ name="supervisor"
143
+ )
144
+ ],
145
+ "draft_answer": f"FINAL ANSWER: {answer}",
146
+ "retry_exhausted": True # Flag to indicate we've exhausted retries
147
+ }
148
+ )
149
+
150
+ # Reset the plan but KEEP the context from previous iterations
151
+ context = state.get("context", {})
152
+ worker_results = state.get("worker_results", {})
153
+
154
+ # Get the critic's reason for rejection, if any
155
+ reason = critic_verdict.get("reason", "")
156
+ if not reason or reason.strip() == "\"":
157
+ reason = "Answer did not meet format requirements"
158
+
159
+ # Check if this is a formatting issue
160
+ format_issues = [
161
+ "format", "concise", "explanation", "not formatted",
162
+ "instead of just", "contains explanations", "FINAL ANSWER"
163
+ ]
164
+ is_format_issue = any(issue in reason.lower() for issue in format_issues)
165
+
166
+ # If we have enough information but the format is wrong, go directly to final answer
167
+ has_sufficient_info = has_sufficient_information(state)
168
+
169
+ if is_format_issue and has_sufficient_info and current_retry_count >= 0:
170
+ # We have information but formatting is wrong - skip planning and go to final answer
171
+ return Command(
172
+ goto="final_answer",
173
+ update={
174
+ "messages": [
175
+ HumanMessage(
176
+ content="We have sufficient information but formatting issues. Generating properly formatted answer.",
177
+ name="supervisor"
178
+ )
179
+ ],
180
+ "retry_count": current_retry_count + 1 # Still increment retry count
181
+ }
182
+ )
183
+
184
+ # Increment the retry counter
185
+ next_retry_count = current_retry_count + 1
186
+
187
+ return Command(
188
+ goto="planner",
189
+ update={
190
+ "plan": None,
191
+ "current_step_index": None,
192
+ "draft_answer": None,
193
+ "critic_verdict": None,
194
+ # Keep the context and worker_results
195
+ "context": context,
196
+ "worker_results": worker_results,
197
+ # Track retries - IMPORTANT: store the incremented count
198
+ "retry_count": next_retry_count,
199
+ # Add a message about the retry (using the INCREMENTED count)
200
+ "messages": [
201
+ HumanMessage(
202
+ content=f"Retrying with new plan (retry #{next_retry_count}). Reason: {reason}",
203
+ name="supervisor"
204
+ )
205
+ ]
206
+ }
207
+ )
208
+
209
+ # Get the current step from the plan
210
+ plan = state["plan"]
211
+ current_step_index = state.get("current_step_index", 0)
212
+
213
+ # Check if we've completed all steps
214
+ if current_step_index >= len(plan["steps"]):
215
+ # Use context to compile the draft answer
216
+ context = state.get("context", {})
217
+
218
+ # Combine the most recent worker outputs as the draft answer
219
+ worker_results = []
220
+ for worker in WORKERS:
221
+ if worker in context:
222
+ worker_results.append(f"**{worker.title()}**: {context[worker]}")
223
+
224
+ # Compile the draft answer from all worker outputs
225
+ draft_content = "\n\n".join(worker_results)
226
+
227
+ # Send to the critic for evaluation
228
+ return Command(
229
+ goto="critic",
230
+ update={
231
+ "draft_answer": draft_content,
232
+ # Add a message about moving to evaluation
233
+ "messages": [
234
+ HumanMessage(
235
+ content="All steps completed. Evaluating the answer.",
236
+ name="supervisor"
237
+ )
238
+ ]
239
+ }
240
+ )
241
+
242
+ # Get the current step
243
+ current_step = plan["steps"][current_step_index]
244
+ worker = current_step["worker"]
245
+ instruction = current_step["instruction"]
246
+
247
+ # Extract only the most relevant context for the current worker and task
248
+ context_info = ""
249
+ if state.get("context"):
250
+ # Filter context by relevance to the current task
251
+ relevant_context = {}
252
+
253
+ # For the coder, extract numerical data and parameters from researcher
254
+ if worker == "coder" and "researcher" in state["context"]:
255
+ relevant_context["researcher"] = state["context"]["researcher"]
256
+
257
+ # For the researcher, previous coder calculations might be relevant
258
+ if worker == "researcher" and "coder" in state["context"]:
259
+ # Only include numerical results from coder, not code snippets
260
+ coder_content = state["context"]["coder"]
261
+ if len(coder_content) < 100: # Only short results are likely just numbers
262
+ relevant_context["coder"] = coder_content
263
+
264
+ # Format the relevant context items
265
+ context_items = []
266
+ for key, value in relevant_context.items():
267
+ # Summarize if value is too long
268
+ if len(value) > 200:
269
+ # Find first sentence or up to 200 chars
270
+ summary = value[:200]
271
+ if '.' in summary:
272
+ summary = summary.split('.')[0] + '.'
273
+ context_items.append(f"Previous {key} found: {summary}...")
274
+ else:
275
+ context_items.append(f"Previous {key} found: {value}")
276
+
277
+ if context_items:
278
+ context_info = "\n\nRelevant context: " + "\n".join(context_items)
279
+
280
+ # Enhance the instruction with context
281
+ enhanced_instruction = f"{instruction}{context_info}"
282
+
283
+ # Add guidance based on worker type
284
+ if worker == "coder":
285
+ enhanced_instruction += "\nProvide both your calculation method AND the final result value."
286
+ elif worker == "researcher":
287
+ enhanced_instruction += "\nFocus on gathering factual information related to the task."
288
+
289
+ # Add the instruction to the messages
290
+ messages_update = [
291
+ HumanMessage(
292
+ content=f"Step {current_step_index + 1}: {enhanced_instruction}",
293
+ name="supervisor"
294
+ )
295
+ ]
296
+
297
+ # Cast worker to appropriate type to satisfy type checking
298
+ worker_destination = cast(SupervisorDestinations, worker)
299
+
300
+ # Move to the appropriate worker
301
+ return Command(
302
+ goto=worker_destination,
303
+ update={
304
+ "messages": messages_update,
305
+ "next": worker, # For backward compatibility
306
+ **state_updates
307
+ }
308
+ )
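
# A minimal sketch of how this node could be wired into a graph; the real
# wiring lives in graph.py, and the node names here simply mirror
# SupervisorDestinations:
#
#     builder = StateGraph(State)
#     builder.add_node("supervisor", supervisor_node)
#     builder.add_edge(START, "supervisor")
#     # ...add planner, critic, researcher, coder and final_answer nodes...
#     graph = builder.compile()
#
# Because supervisor_node returns Command[SupervisorDestinations], LangGraph
# derives the possible outgoing edges from the Command's `goto` values.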


def extract_best_answer_from_context(context):
    """Extract the best available answer from context.

    This is a generic function to extract answers from any type of question context.
    It progressively tries different strategies to find a suitable answer.

    Args:
        context: The state context containing worker outputs

    Returns:
        Best answer found, or "unknown" if nothing suitable is found
    """
    answer = "unknown"

    # First check if the coder already provided a properly formatted answer
    if "coder" in context:
        coder_content = context["coder"]

        # Look for a "FINAL ANSWER: X" pattern in the coder output
        answer_match = re.search(r"FINAL ANSWER:\s*(.*?)(?:\n|$)", coder_content, re.IGNORECASE)
        if answer_match:
            return answer_match.group(1).strip()

    # If no answer in the coder output, check the researcher content
    if "researcher" in context:
        researcher_content = context["researcher"]

        # Look for bulleted list items (a common pattern in research notes)
        list_items = re.findall(r"[-•*]\s+([^:\n]+)", researcher_content)
        if list_items:
            # Format as a comma-separated list
            answer = ",".join(item.strip() for item in list_items)
            return answer

        # Look for emphasized/bold items, which might mark key information
        bold_items = re.findall(r"\*\*([^*]+)\*\*", researcher_content)
        if bold_items:
            # Join the important items as a comma-separated list
            processed_items = []
            for item in bold_items:
                # Remove common filler words and clean up the item
                clean_item = re.sub(r'(^|\s)(a|an|the|is|are|was|were|be|been)(\s|$)', ' ', item)
                clean_item = clean_item.strip()
                if clean_item and len(clean_item) < 30:  # Only include reasonably short items
                    processed_items.append(clean_item)

            if processed_items:
                answer = ",".join(processed_items)
                return answer

    # If we still don't have an answer, try to extract common entities
    combined_content = ""
    for worker_type, content in context.items():
        combined_content += " " + content

    # Look for numbers in the combined content
    numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', combined_content)
    if numbers:
        answer = numbers[0]  # Use the first number found

    return answer
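
# Illustrative behaviour of the cascade above (inputs are made-up examples):
#     extract_best_answer_from_context({"coder": "FINAL ANSWER: 42"})    -> "42"
#     extract_best_answer_from_context({"researcher": "- red\n- blue"})  -> "red,blue"
#     extract_best_answer_from_context({})                               -> "unknown"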


def has_sufficient_information(state):
    """Determine if we have enough information to generate a final answer.

    Args:
        state: The current conversation state

    Returns:
        Boolean indicating if we have sufficient information
    """
    context = state.get("context", {})

    # If we have both researcher and coder outputs, we likely have enough info
    if "researcher" in context and "coder" in context:
        return True

    # A substantial researcher output on its own might be enough
    if "researcher" in context and len(context["researcher"]) > 150:
        return True

    # Any worker output that contains lists or formatted data also counts
    for worker, content in context.items():
        if content and (
            "- " in content or   # Bullet point
            "•" in content or    # Bullet point
            "*" in content or    # Emphasis or bullet
            ":" in content       # Definition or explanation
        ):
            return True

    return False
tools.py ADDED
@@ -0,0 +1,58 @@
+ """This module provides tools for the agent supervisor.
2
+
3
+ It includes:
4
+ - Web Search: For general web results using Tavily.
5
+ - Python REPL: For executing Python code (Use with caution!).
6
+ """
7
+
8
+ from typing import Annotated, List, Any, Callable, Optional, cast
9
+
10
+ # Core Tools & Utilities
11
+ from langchain_core.tools import tool
12
+
13
+ # Experimental Tools (Use with caution)
14
+ from langchain_experimental.utilities import PythonREPL
15
+
16
+ # Use TavilySearchResults from langchain_community like in the notebook
17
+ from langchain_community.tools.tavily_search import TavilySearchResults
18
+ from react_agent.configuration import Configuration
19
+
20
+
21
+ # Create Tavily tool using configuration from context (more consistent approach)
22
+ def create_tavily_tool():
23
+ """Create the Tavily search tool with configuration from context.
24
+
25
+ Returns:
26
+ Configured TavilySearchResults tool
27
+ """
28
+ configuration = Configuration.from_context()
29
+ return TavilySearchResults(max_results=configuration.max_search_results)
30
+
31
+ # Initialize the tool
32
+ tavily_tool = create_tavily_tool()
33
+
34
+
35
+ # --- Python REPL Tool ---
36
+ # WARNING: Executes arbitrary Python code locally. Be extremely careful
37
+ # about exposing this tool, especially in production environments.
38
+ repl = PythonREPL()
39
+
40
+ @tool
41
+ def python_repl_tool(
42
+ code: Annotated[str, "The python code to execute. Use print(...) to see output."],
43
+ ):
44
+ """Use this to execute python code. If you want to see the output of a value,
45
+ you should print it out with `print(...)`. This is visible to the user."""
46
+ try:
47
+ result = repl.run(code)
48
+ except BaseException as e:
49
+ return f"Failed to execute. Error: {repr(e)}"
50
+ # Filter out potentially sensitive REPL implementation details
51
+ result_str = f"Successfully executed:\n\`\`\`python\n{code}\n\`\`\`\nStdout: {result}"
52
+ return result_str
53
+
54
+
55
+ # --- Tool List ---
56
+
57
+ # The list of tools available to the agent supervisor.
58
+ TOOLS: List[Callable[..., Any]] = [tavily_tool, python_repl_tool]
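
# Illustrative usage (tavily_tool needs TAVILY_API_KEY in the environment):
#     python_repl_tool.invoke({"code": "print(2 + 2)"})
#     tavily_tool.invoke({"query": "LangGraph supervisor pattern"})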
utils.py ADDED
@@ -0,0 +1,76 @@
+ """Utility & helper functions."""
2
+
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from langchain.chat_models import init_chat_model
6
+ from langchain_core.language_models import BaseChatModel
7
+ from langchain_core.messages import BaseMessage
8
+ import asyncio
9
+ from datetime import UTC, datetime
10
+ from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS
11
+
12
+
13
+ # Load environment variables from .env file
14
+ load_dotenv()
15
+
16
+
17
+ def get_message_text(msg: BaseMessage) -> str:
18
+ """Get the text content of a message."""
19
+ content = msg.content
20
+ if isinstance(content, str):
21
+ return content
22
+ elif isinstance(content, dict):
23
+ return content.get("text", "")
24
+ else:
25
+ txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
26
+ return "".join(txts).strip()
27
+
28
+
29
+ def format_system_prompt(prompt_template: str) -> str:
30
+ """Format a system prompt template with current system time and available agents.
31
+
32
+ Args:
33
+ prompt_template: The prompt template to format
34
+
35
+ Returns:
36
+ The formatted prompt with system time and agent information
37
+ """
38
+ # Get example workers for templates
39
+ example_worker_1 = WORKERS[0] if WORKERS else "researcher"
40
+ example_worker_2 = WORKERS[1] if len(WORKERS) > 1 else "coder"
41
+
42
+ # Get verdicts for templates
43
+ correct_verdict = VERDICTS[0] if VERDICTS else "CORRECT"
44
+ retry_verdict = VERDICTS[1] if len(VERDICTS) > 1 else "RETRY"
45
+
46
+ return prompt_template.format(
47
+ system_time=datetime.now(tz=UTC).isoformat(),
48
+ workers=", ".join(WORKERS),
49
+ members=", ".join(MEMBERS),
50
+ worker_options=", ".join([f'"{w}"' for w in WORKERS]),
51
+ example_worker_1=example_worker_1,
52
+ example_worker_2=example_worker_2,
53
+ correct_verdict=correct_verdict,
54
+ retry_verdict=retry_verdict
55
+ )
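
# Illustrative only: a template such as
#     "You manage these workers: {workers}. System time: {system_time}."
# would come back with workers="researcher, coder" and the current UTC timestamp.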


def load_chat_model(fully_specified_name: str) -> BaseChatModel:
    """Load a chat model from a fully specified name.

    Args:
        fully_specified_name (str): String in the format 'provider/model'.
    """
    provider, model = fully_specified_name.split("/", maxsplit=1)

    # Special handling for Google Genai models to ensure they're configured for async
    if provider == "google_genai":
        from langchain_google_genai import ChatGoogleGenerativeAI

        # Make sure we have the API key
        if not os.environ.get("GOOGLE_API_KEY"):
            raise ValueError("GOOGLE_API_KEY environment variable is required for google_genai models")

        return ChatGoogleGenerativeAI(model=model)
    else:
        return init_chat_model(model, model_provider=provider)
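
# Illustrative usage (model names are examples, not the project defaults):
#     chat = load_chat_model("google_genai/gemini-2.0-flash")
#     chat = load_chat_model("openai/gpt-4o-mini")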