File size: 16,200 Bytes
ac6a4ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
"""Supervisor node implementation for the agent supervisor system."""

from typing import Dict, List, Literal, Optional, Union, Type, cast

from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command

from react_agent.configuration import Configuration
from react_agent.state import WORKERS, MEMBERS, ROUTING, VERDICTS, State, Router
from react_agent.utils import load_chat_model, format_system_prompt, get_message_text
from react_agent import prompts


# Compile-time type definitions
SupervisorDestinations = Literal["planner", "critic", "researcher", "coder", "final_answer", "__end__"]


def supervisor_node(state: State) -> Command[SupervisorDestinations]:
    """Supervising LLM that decides which specialized agent should act next.

    Routing priority (first match wins):
      1. Hard step-budget guard -> "final_answer" with the best extractable answer.
      2. Hard retry-budget guard -> "final_answer" likewise.
      3. No plan yet (or a plan with zero steps) -> "planner".
      4. Critic verdict: CORRECT -> "final_answer"; RETRY -> either "final_answer"
         (budget exhausted / pure formatting issue) or "planner" with context kept.
      5. All plan steps done -> compile worker outputs and go to "critic".
      6. Otherwise -> dispatch the current plan step to its worker.

    Args:
        state: The current state with messages

    Returns:
        Command with routing information
    """
    # Get configuration to use supervisor_model
    configuration = Configuration.from_context()
    
    # Track steps to prevent infinite loops
    steps_taken = state.get("steps_taken", 0)
    steps_taken += 1
    state_updates = {"steps_taken": steps_taken}
    
    # Check if we've hit our step limit
    if steps_taken >= configuration.recursion_limit - 5:  # Buffer of 5 steps
        # Extract the best answer we have from context if possible
        context = state.get("context", {})
        answer = extract_best_answer_from_context(context)
        
        return Command(
            goto="final_answer",
            update={
                "messages": [
                    HumanMessage(
                        content=f"Maximum steps ({steps_taken}) reached. Extracting best answer from available information.",
                        name="supervisor"
                    )
                ],
                "draft_answer": f"FINAL ANSWER: {answer}",
                "retry_exhausted": True,  # Flag to indicate we've exhausted retries
                "steps_taken": steps_taken
            }
        )
    
    # Safety check - prevent infinite loops by forcing termination after too many retry steps
    retry_count = state.get("retry_count", 0)
    max_retries = 2  # Maximum number of allowed retries
    
    if retry_count > max_retries:
        # Extract the best answer we have from context if possible
        context = state.get("context", {})
        answer = extract_best_answer_from_context(context)
        
        return Command(
            goto="final_answer",
            update={
                "messages": [
                    HumanMessage(
                        content=f"Maximum retries ({max_retries}) reached. Extracting best answer from available information.",
                        name="supervisor"
                    )
                ],
                "draft_answer": f"FINAL ANSWER: {answer}",
                "retry_exhausted": True,  # Flag to indicate we've exhausted retries
                "steps_taken": steps_taken
            }
        )
        
    # Check if we need a plan
    if not state.get("plan"):
        return Command(
            goto="planner",
            update={
                **state_updates
            }
        )
    
    # Validate that the plan has at least one step
    plan = state.get("plan")
    if not plan.get("steps") or len(plan.get("steps", [])) == 0:
        # Plan has no steps, go back to planner with explicit instructions
        return Command(
            goto="planner",
            update={
                "messages": [
                    HumanMessage(
                        content="Previous plan had 0 steps. Please create a plan with at least 1 step to solve the user's question.",
                        name="supervisor"
                    )
                ],
                "plan": None,
                **state_updates
            }
        )
    
    # Check if we have a critic verdict that requires replanning
    critic_verdict = state.get("critic_verdict")
    if critic_verdict:
        if critic_verdict.get("verdict") == VERDICTS[0]:  # CORRECT
            # Final answer is approved, navigate to the final_answer node
            # This will generate a polished response before ending
            # NOTE(review): steps_taken (state_updates) is not propagated on this
            # branch, unlike most others — confirm this is intentional.
            return Command(
                goto="final_answer",
                update={
                    "messages": [
                        HumanMessage(
                            content="Answer approved by critic. Generating final response.",
                            name="supervisor"
                        )
                    ]
                }
            )
        elif critic_verdict.get("verdict") == VERDICTS[1]:  # RETRY
            # IMPORTANT: Get the current retry count BEFORE incrementing
            current_retry_count = state.get("retry_count", 0)
            
            # Check if we're at the maximum allowed retries
            if current_retry_count >= max_retries:
                # Extract best answer and go to final_answer
                context = state.get("context", {})
                answer = extract_best_answer_from_context(context)
                
                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            HumanMessage(
                                content=f"Maximum retries ({max_retries}) reached. Proceeding with best available answer.",
                                name="supervisor"
                            )
                        ],
                        "draft_answer": f"FINAL ANSWER: {answer}",
                        "retry_exhausted": True  # Flag to indicate we've exhausted retries
                    }
                )
            
            # Reset the plan but KEEP the context from previous iterations
            context = state.get("context", {})
            worker_results = state.get("worker_results", {})
            
            # Get the critic's reason for rejection, if any
            reason = critic_verdict.get("reason", "")
            # NOTE(review): this compares the stripped reason to a literal
            # double-quote character ('"'); if the intent was a blank/whitespace
            # check, it should be == "" — confirm against critic output format.
            if not reason or reason.strip() == "\"":
                reason = "Answer did not meet format requirements"
                
            # Check if this is a formatting issue
            format_issues = [
                "format", "concise", "explanation", "not formatted", 
                "instead of just", "contains explanations", "FINAL ANSWER"
            ]
            is_format_issue = any(issue in reason.lower() for issue in format_issues)
            
            # If we have enough information but the format is wrong, go directly to final answer
            has_sufficient_info = has_sufficient_information(state)
            
            # NOTE(review): current_retry_count >= 0 is always true (it defaults
            # to 0 and only increments) — likely a vestigial condition.
            if is_format_issue and has_sufficient_info and current_retry_count >= 0:
                # We have information but formatting is wrong - skip planning and go to final answer
                return Command(
                    goto="final_answer",
                    update={
                        "messages": [
                            HumanMessage(
                                content="We have sufficient information but formatting issues. Generating properly formatted answer.",
                                name="supervisor"
                            )
                        ],
                        "retry_count": current_retry_count + 1  # Still increment retry count
                    }
                )
            
            # Increment the retry counter
            next_retry_count = current_retry_count + 1
            
            return Command(
                goto="planner", 
                update={
                    "plan": None, 
                    "current_step_index": None,
                    "draft_answer": None,
                    "critic_verdict": None,
                    # Keep the context and worker_results
                    "context": context,
                    "worker_results": worker_results,
                    # Track retries - IMPORTANT: store the incremented count
                    "retry_count": next_retry_count,
                    # Add a message about the retry (using the INCREMENTED count)
                    "messages": [
                        HumanMessage(
                            content=f"Retrying with new plan (retry #{next_retry_count}). Reason: {reason}",
                            name="supervisor"
                        )
                    ]
                }
            )
    
    # Get the current step from the plan
    plan = state["plan"]
    current_step_index = state.get("current_step_index", 0)
    
    # Check if we've completed all steps
    if current_step_index >= len(plan["steps"]):
        # Use context to compile the draft answer
        context = state.get("context", {})
        
        # Combine the most recent worker outputs as the draft answer
        worker_results = []
        for worker in WORKERS:
            if worker in context:
                worker_results.append(f"**{worker.title()}**: {context[worker]}")
        
        # Compile the draft answer from all worker outputs
        draft_content = "\n\n".join(worker_results)
        
        # Send to the critic for evaluation
        return Command(
            goto="critic",
            update={
                "draft_answer": draft_content,
                # Add a message about moving to evaluation
                "messages": [
                    HumanMessage(
                        content="All steps completed. Evaluating the answer.",
                        name="supervisor"
                    )
                ]
            }
        )
    
    # Get the current step
    current_step = plan["steps"][current_step_index]
    worker = current_step["worker"]
    instruction = current_step["instruction"]
    
    # Extract only the most relevant context for the current worker and task
    context_info = ""
    if state.get("context"):
        # Filter context by relevance to the current task
        relevant_context = {}
        
        # For the coder, extract numerical data and parameters from researcher
        if worker == "coder" and "researcher" in state["context"]:
            relevant_context["researcher"] = state["context"]["researcher"]
        
        # For the researcher, previous coder calculations might be relevant
        if worker == "researcher" and "coder" in state["context"]:
            # Only include numerical results from coder, not code snippets
            coder_content = state["context"]["coder"]
            if len(coder_content) < 100:  # Only short results are likely just numbers
                relevant_context["coder"] = coder_content
        
        # Format the relevant context items
        context_items = []
        for key, value in relevant_context.items():
            # Summarize if value is too long
            if len(value) > 200:
                # Find first sentence or up to 200 chars
                summary = value[:200]
                if '.' in summary:
                    summary = summary.split('.')[0] + '.'
                context_items.append(f"Previous {key} found: {summary}...")
            else:
                context_items.append(f"Previous {key} found: {value}")
        
        if context_items:
            context_info = "\n\nRelevant context: " + "\n".join(context_items)
    
    # Enhance the instruction with context
    enhanced_instruction = f"{instruction}{context_info}"
    
    # Add guidance based on worker type
    if worker == "coder":
        enhanced_instruction += "\nProvide both your calculation method AND the final result value."
    elif worker == "researcher":
        enhanced_instruction += "\nFocus on gathering factual information related to the task."
    
    # Add the instruction to the messages
    messages_update = [
        HumanMessage(
            content=f"Step {current_step_index + 1}: {enhanced_instruction}",
            name="supervisor"
        )
    ]
    
    # Cast worker to appropriate type to satisfy type checking
    worker_destination = cast(SupervisorDestinations, worker)
    
    # Move to the appropriate worker
    return Command(
        goto=worker_destination,
        update={
            "messages": messages_update,
            "next": worker,  # For backward compatibility
            **state_updates
        }
    )

def extract_best_answer_from_context(context: Dict[str, str]) -> str:
    """Extract the best available answer from worker context.

    Tries progressively weaker strategies, returning on the first success:
      1. An explicit "FINAL ANSWER: ..." line in the coder output.
      2. Bulleted list items in the researcher output (joined by commas).
      3. Bold (**...**) items in the researcher output, cleaned of filler words.
      4. The first number found anywhere in the combined worker outputs.

    Args:
        context: Mapping of worker name to that worker's text output.

    Returns:
        Best answer found, or "unknown" if nothing suitable is found.
    """
    # Single local import instead of the previous per-branch re-imports.
    import re

    # Strategy 1: the coder may already have produced a formatted answer.
    if "coder" in context:
        answer_match = re.search(
            r"FINAL ANSWER:\s*(.*?)(?:\n|$)", context["coder"], re.IGNORECASE
        )
        if answer_match:
            return answer_match.group(1).strip()

    # Strategies 2 and 3: mine the researcher output for structured content.
    if "researcher" in context:
        researcher_content = context["researcher"]

        # Bulleted list items are a common pattern for enumerated answers.
        list_items = re.findall(r"[-•*]\s+([^:\n]+)", researcher_content)
        if list_items:
            return ",".join(item.strip() for item in list_items)

        # Bold/emphasized items often mark the key facts.
        bold_items = re.findall(r"\*\*([^*]+)\*\*", researcher_content)
        processed_items = []
        for item in bold_items:
            # Remove common filler words, then surrounding whitespace.
            clean_item = re.sub(
                r'(^|\s)(a|an|the|is|are|was|were|be|been)(\s|$)', ' ', item
            )
            clean_item = clean_item.strip()
            # Long items are prose rather than answers; keep only short ones.
            if clean_item and len(clean_item) < 30:
                processed_items.append(clean_item)
        if processed_items:
            return ",".join(processed_items)

    # Strategy 4 (last resort): the first number appearing in any output.
    combined_content = " ".join(context.values())
    numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', combined_content)
    if numbers:
        return numbers[0]

    return "unknown"

def has_sufficient_information(state):
    """Determine if we have enough information to generate a final answer.
    
    Args:
        state: The current conversation state
        
    Returns:
        Boolean indicating if we have sufficient information
    """
    context = state.get("context", {})

    # Both primary workers have reported back -- almost certainly enough.
    if "researcher" in context and "coder" in context:
        return True

    # A substantial researcher write-up alone may suffice.
    if "researcher" in context and len(context["researcher"]) > 150:
        return True

    # Any worker output containing list or structure markers
    # ("- " / "•" bullets, "*" emphasis, ":" definitions) counts.
    structure_markers = ("- ", "•", "*", ":")
    return any(
        content and any(marker in content for marker in structure_markers)
        for content in context.values()
    )