"""Run a smolagents CodeAgent one step at a time inside a LangGraph state machine.

Each invocation of the ``process_step`` node advances the agent by at most one
step (``max_steps=1`` with ``reset=False`` so the agent keeps its internal
memory), and a ``MemorySaver`` checkpointer persists graph state per
``thread_id`` so runs can be resumed.
"""

import asyncio
import logging
import os
import uuid  # for generating thread IDs for checkpointer
from typing import AsyncIterator, Optional, TypedDict

from dotenv import find_dotenv, load_dotenv
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from smolagents import CodeAgent, LiteLLMModel
from smolagents.memory import ActionStep, FinalAnswerStep
from smolagents.monitoring import LogLevel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv(find_dotenv())

# Get required environment variables with validation
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

if not all([API_BASE, API_KEY, MODEL_ID]):
    raise ValueError(
        "Missing required environment variables: API_BASE, API_KEY, MODEL_ID"
    )


class AgentState(TypedDict):
    """State dict carried between LangGraph node invocations."""

    task: str
    current_step: Optional[dict]  # Store serializable dict instead of ActionStep
    error: Optional[str]
    answer_text: Optional[str]


# Initialize model with error handling
try:
    model = LiteLLMModel(
        api_base=API_BASE,
        api_key=API_KEY,
        model_id=MODEL_ID,
    )
except Exception as e:
    logger.error("Failed to initialize model: %s", e)
    raise

# Initialize agent with error handling
try:
    agent = CodeAgent(
        add_base_tools=True,
        additional_authorized_imports=["pandas", "numpy"],
        max_steps=10,
        model=model,
        tools=[],
        step_callbacks=None,
        verbosity_level=LogLevel.ERROR,
    )
    agent.logger.console.width = 66
except Exception as e:
    logger.error("Failed to initialize agent: %s", e)
    raise


def _unwrap_node_state(update: Optional[dict]) -> Optional[dict]:
    """Extract the inner AgentState from a LangGraph stream update.

    ``graph.astream`` yields dicts shaped ``{node_name: state_dict}``; this
    returns the state dict (or None when the update is empty/None).
    """
    if not update:
        return None
    return next(iter(update.values()), None)


async def process_step(state: AgentState) -> AgentState:
    """Process a single step of the agent's execution.

    Runs the shared ``agent`` for one step (``max_steps=1``, ``reset=False``
    so the agent's internal memory survives across calls) and records either
    the serialized ActionStep or the final answer into ``state``.

    Returns the (mutated) state; on exception the error message is stored in
    ``state["error"]`` instead of propagating.
    """
    try:
        # Clear previous step results before running agent.run
        state["current_step"] = None
        state["answer_text"] = None
        state["error"] = None

        steps = agent.run(
            task=state["task"],
            additional_args=None,
            images=None,
            max_steps=1,  # Process one step at a time
            stream=True,
            reset=False,  # Maintain agent's internal state across process_step calls
        )

        for step in steps:
            if isinstance(step, ActionStep):
                # Convert ActionStep to serializable dict using the correct attributes
                state["current_step"] = {
                    "step_number": step.step_number,
                    "model_output": step.model_output,
                    "observations": step.observations,
                    "tool_calls": [
                        {"name": tc.name, "arguments": tc.arguments}
                        for tc in (step.tool_calls or [])
                    ],
                    "action_output": step.action_output,
                }
                logger.info("Processed action step %s", step.step_number)
            elif isinstance(step, FinalAnswerStep):
                state["answer_text"] = step.final_answer
                logger.info("Processed final answer")
                logger.debug("Final answer details: %s", step)
                logger.info("Extracted answer text: %s", state["answer_text"])
                # Return immediately when we get a final answer
                return state

        # If loop finishes without FinalAnswerStep, return current state
        return state
    except Exception as e:
        state["error"] = str(e)
        logger.error("Error during agent execution step: %s", e)
        return state


def should_continue(state: AgentState) -> bool:
    """Determine if the agent should continue processing steps.

    Continue only while there is neither a final answer nor an error.
    """
    continue_execution = (
        state.get("answer_text") is None and state.get("error") is None
    )
    logger.debug(
        f"Checking should_continue: answer_text={state.get('answer_text') is not None}, error={state.get('error') is not None} -> Continue={continue_execution}"
    )
    return continue_execution


# Build the LangGraph graph once with persistence
memory = MemorySaver()
builder = StateGraph(AgentState)
builder.add_node("process_step", process_step)
builder.add_edge(START, "process_step")
builder.add_conditional_edges(
    "process_step", should_continue, {True: "process_step", False: END}
)
graph = builder.compile(checkpointer=memory)


async def stream_execution(task: str, thread_id: str) -> AsyncIterator[AgentState]:
    """Stream the execution of the agent.

    Yields each raw update from ``graph.astream`` (shaped
    ``{node_name: state}``). Raises ValueError on an empty task, and
    re-raises the agent's recorded error as an Exception when a step fails
    without producing an answer.
    """
    if not task:
        raise ValueError("Task cannot be empty")

    logger.info("Initializing agent execution for task: %s", task)

    # Initialize the state
    initial_state: AgentState = {
        "task": task,
        "current_step": None,
        "error": None,
        "answer_text": None,
    }

    # Pass thread_id via the config dict so the checkpointer can persist state
    async for update in graph.astream(
        initial_state, {"configurable": {"thread_id": thread_id}}
    ):
        yield update
        # BUGFIX: astream yields {node_name: state}, so error must be read
        # from the inner node state, not the outer update dict.
        node_state = _unwrap_node_state(update)
        if node_state and node_state.get("error") and not node_state.get("answer_text"):
            logger.error("Propagating error from stream: %s", node_state["error"])
            raise Exception(node_state["error"])


async def run_with_streaming(task: str, thread_id: str) -> dict:
    """Run the agent with streaming output and return the results.

    Prints each new action step as it arrives and collects them, then
    extracts the final answer / error from the last streamed state.

    Returns a dict: ``{"steps": list[dict], "final_answer": Optional[str],
    "error": Optional[str]}``. Exceptions are caught and reported via the
    ``"error"`` key rather than raised.
    """
    last_state = None
    steps = []
    error = None
    final_answer_text = None

    try:
        logger.info("Starting execution run for task: %s", task)
        async for state in stream_execution(task, thread_id):
            last_state = state
            # BUGFIX: unwrap the node-keyed update before reading fields;
            # the outer dict never carries "current_step" itself.
            node_state = _unwrap_node_state(state) or {}
            if current_step := node_state.get("current_step"):
                if (
                    not steps
                    or steps[-1]["step_number"] != current_step["step_number"]
                ):
                    steps.append(current_step)
                    # Keep print here for direct user feedback during streaming
                    print(f"\nStep {current_step['step_number']}:")
                    print(f"Model Output: {current_step['model_output']}")
                    print(f"Observations: {current_step['observations']}")
                    if current_step.get("tool_calls"):
                        print("Tool Calls:")
                        for tc in current_step["tool_calls"]:
                            print(f"  - {tc['name']}: {tc['arguments']}")
                    if current_step.get("action_output"):
                        print(f"Action Output: {current_step['action_output']}")

        # After the stream is finished, process the last state
        logger.info("Stream finished.")
        if last_state:
            # LangGraph streams dicts where keys are node names, values are state dicts
            actual_state = _unwrap_node_state(last_state)
            if actual_state:
                final_answer_text = actual_state.get("answer_text")
                error = actual_state.get("error")
                logger.info(
                    "Final answer text extracted from last state: %s",
                    final_answer_text,
                )
                logger.info("Error extracted from last state: %s", error)
                # Ensure steps list is consistent with the final state if needed
                last_step_in_state = actual_state.get("current_step")
                if last_step_in_state and (
                    not steps
                    or steps[-1]["step_number"] != last_step_in_state["step_number"]
                ):
                    logger.debug("Adding last step from final state to steps list.")
                    steps.append(last_step_in_state)
            else:
                logger.warning(
                    "Could not find actual state dictionary within last_state."
                )

        return {"steps": steps, "final_answer": final_answer_text, "error": error}

    except Exception as e:
        import traceback

        logger.error(
            f"Exception during run_with_streaming: {str(e)}\n{traceback.format_exc()}"
        )
        # Attempt to return based on the last known state even if exception occurred outside stream
        final_answer_text = None
        error_msg = str(e)
        if last_state:
            actual_state = _unwrap_node_state(last_state)
            if actual_state:
                final_answer_text = actual_state.get("answer_text")
        return {"steps": steps, "final_answer": final_answer_text, "error": error_msg}


def main(task: str, thread_id: Optional[str] = None):
    """Run a task synchronously and print the results.

    Returns the final answer text (or None when not found).
    """
    # BUGFIX: the original default `thread_id=str(uuid.uuid4())` was evaluated
    # once at definition time, so every call omitting thread_id silently
    # shared one checkpointer thread. Generate a fresh ID per call instead.
    if thread_id is None:
        thread_id = str(uuid.uuid4())

    logger.info(
        f"Starting agent run from __main__ for task: '{task}' with thread_id: {thread_id}"
    )
    result = asyncio.run(run_with_streaming(task, thread_id))
    logger.info("Agent run finished.")

    # Print final results
    print("\n--- Execution Results ---")
    print(f"Number of Steps: {len(result.get('steps', []))}")
    # Optionally print step details
    # for i, step in enumerate(result.get('steps', [])):
    #     print(f"Step {i+1} Details: {step}")
    print(f"Final Answer: {result.get('final_answer') or 'Not found'}")
    if err := result.get("error"):
        print(f"Error: {err}")
    return result.get("final_answer")


if __name__ == "__main__":
    # Example Usage
    task_to_run = "What is the capital of France?"
    thread_id = str(uuid.uuid4())  # Generate a unique thread ID for this run
    final_answer = main(task_to_run, thread_id)
    print(f"Final Answer: {final_answer}")