"""Manager agent entry point.

Registers the SmolagentsInstrumentor for telemetry, builds a CodeAgent with
managed web, data-analysis, and media sub-agents, and drives execution through
a checkpointed LangGraph streaming loop before extracting the final answer."""
import asyncio
import importlib
import logging
import os
import time
import uuid  # For generating thread IDs for the checkpointer
from typing import AsyncIterator, Optional, TypedDict
import litellm
import yaml
from dotenv import find_dotenv, load_dotenv
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from smolagents import CodeAgent, LiteLLMModel
from smolagents.memory import ActionStep, FinalAnswerStep
from smolagents.monitoring import LogLevel
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register

from agents import create_data_analysis_agent, create_media_agent, create_web_agent
from prompts import MANAGER_SYSTEM_PROMPT
from tools import perform_calculation, web_search
from utils import extract_final_answer
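# Telemetry setup: _turn_on_debug() is a private LiteLLM helper and is very
# verbose, so drop it for quieter runs. register() sets up an OpenTelemetry
# tracer for Phoenix (by default it should target a local Phoenix collector,
# overridable via PHOENIX_COLLECTOR_ENDPOINT; this is an assumption about
# phoenix.otel defaults), and SmolagentsInstrumentor().instrument() patches
# smolagents so every agent step is traced automatically.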
litellm._turn_on_debug()
register()
SmolagentsInstrumentor().instrument()
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv(find_dotenv())
# Get required environment variables with validation
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")
if not all([API_BASE, API_KEY, MODEL_ID]):
raise ValueError(
"Missing required environment variables: API_BASE, API_KEY, MODEL_ID"
)
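# Example .env entries (placeholder values only; MODEL_ID should be a
# LiteLLM-style model string for your provider):
#   API_BASE=https://api.openai.com/v1
#   API_KEY=sk-...
#   MODEL_ID=openai/gpt-4o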
# Define the state types for our graph
class AgentState(TypedDict):
task: str
current_step: Optional[dict] # Store serializable dict instead of ActionStep
error: Optional[str]
answer_text: Optional[str]
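# Each LangGraph node receives this state and returns an updated copy:
# process_step records the latest ActionStep as a plain dict and sets
# answer_text once a FinalAnswerStep arrives; should_continue loops until
# answer_text or error is populated.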
# Initialize model with error handling
try:
model = LiteLLMModel(
api_base=API_BASE,
api_key=API_KEY,
model_id=MODEL_ID,
)
except Exception as e:
logger.error(f"Failed to initialize model: {str(e)}")
raise
web_agent = create_web_agent(model)
data_agent = create_data_analysis_agent(model)
media_agent = create_media_agent(model)
tools = [
# DuckDuckGoSearchTool(max_results=3),
# VisitWebpageTool(max_output_length=1000),
web_search,
perform_calculation,
]
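# add_base_tools=True below asks smolagents to merge its built-in toolbox in
# as well, so agent.tools will typically list more than these two entries.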
# Initialize agent with error handling
try:
prompt_templates = yaml.safe_load(
importlib.resources.files("smolagents.prompts")
.joinpath("code_agent.yaml")
.read_text()
)
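    # This YAML is the stock CodeAgent prompt set bundled with the smolagents
    # package; the commented line below would swap in the custom manager
    # system prompt instead.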
# prompt_templates["system_prompt"] = MANAGER_SYSTEM_PROMPT
agent = CodeAgent(
add_base_tools=True,
additional_authorized_imports=[
"json",
"pandas",
"numpy",
"re",
],
# max_steps=10,
managed_agents=[web_agent, data_agent, media_agent],
model=model,
prompt_templates=prompt_templates,
tools=tools,
step_callbacks=None,
verbosity_level=LogLevel.ERROR,
)
agent.logger.console.width = 66
agent.visualize()
tools = agent.tools
print(f"Tools: {tools}")
except Exception as e:
logger.error(f"Failed to initialize agent: {str(e)}")
raise
async def process_step(state: AgentState) -> AgentState:
"""Process a single step of the agent's execution."""
try:
# Clear previous step results before running agent.run
state["current_step"] = None
state["answer_text"] = None
state["error"] = None
steps = agent.run(
task=state["task"],
additional_args=None,
images=None,
# max_steps=1, # Process one step at a time
stream=True,
reset=False, # Maintain agent's internal state across process_step calls
)
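        # With stream=True, run() yields intermediate ActionStep objects as
        # the agent works and a FinalAnswerStep once it is done.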
for step in steps:
if isinstance(step, ActionStep):
# Convert ActionStep to serializable dict using the correct attributes
state["current_step"] = {
"step_number": step.step_number,
"model_output": step.model_output,
"observations": step.observations,
"tool_calls": [
{"name": tc.name, "arguments": tc.arguments}
for tc in (step.tool_calls or [])
],
"action_output": step.action_output,
}
logger.info(f"Processed action step {step.step_number}")
logger.info(f"Step {step.step_number} details: {step}")
logger.info(f"Sleeping for 60 seconds...")
time.sleep(60)
elif isinstance(step, FinalAnswerStep):
state["answer_text"] = step.final_answer
logger.info("Processed final answer")
logger.debug(f"Final answer details: {step}")
logger.info(f"Extracted answer text: {state['answer_text']}")
# Return immediately when we get a final answer
return state
# If loop finishes without FinalAnswerStep, return current state
return state
except Exception as e:
state["error"] = str(e)
logger.error(f"Error during agent execution step: {str(e)}")
return state
def should_continue(state: AgentState) -> bool:
    """Determine if the agent should continue processing steps."""
    # Continue only while there is neither an answer nor an error
    continue_execution = state.get("answer_text") is None and state.get("error") is None
    logger.debug(
        f"Checking should_continue: has_answer={state.get('answer_text') is not None}, "
        f"has_error={state.get('error') is not None} -> Continue={continue_execution}"
    )
    return continue_execution
# Build the LangGraph graph once with persistence
memory = MemorySaver()
builder = StateGraph(AgentState)
builder.add_node("process_step", process_step)
builder.add_edge(START, "process_step")
builder.add_conditional_edges(
"process_step", should_continue, {True: "process_step", False: END}
)
graph = builder.compile(checkpointer=memory)
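# Resulting topology:
#   START -> process_step -> (should_continue?) -> process_step | END
# MemorySaver checkpoints AgentState per thread_id, so runs that reuse a
# thread ID resume from the saved state.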
async def stream_execution(task: str, thread_id: str) -> AsyncIterator[AgentState]:
"""Stream the execution of the agent."""
if not task:
raise ValueError("Task cannot be empty")
logger.info(f"Initializing agent execution for task: {task}")
# Initialize the state
initial_state: AgentState = {
"task": task,
"current_step": None,
"error": None,
"answer_text": None,
}
# Pass thread_id via the config dict so the checkpointer can persist state
async for state in graph.astream(
initial_state, {"configurable": {"thread_id": thread_id}}
):
yield state
# Propagate error immediately if it occurs without an answer
if state.get("error") and not state.get("answer_text"):
logger.error(f"Propagating error from stream: {state['error']}")
            raise RuntimeError(state["error"])
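# graph.astream yields one {node_name: state} mapping per executed node;
# run_with_streaming unwraps that mapping to recover the actual AgentState.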
async def run_with_streaming(task: str, thread_id: str) -> dict:
"""Run the agent with streaming output and return the results."""
last_state = None
steps = []
error = None
final_answer_text = None
try:
logger.info(f"Starting execution run for task: {task}")
async for state in stream_execution(task, thread_id):
last_state = state
if current_step := state.get("current_step"):
if not steps or steps[-1]["step_number"] != current_step["step_number"]:
steps.append(current_step)
# Keep print here for direct user feedback during streaming
print(f"\nStep {current_step['step_number']}:")
print(f"Model Output: {current_step['model_output']}")
print(f"Observations: {current_step['observations']}")
if current_step.get("tool_calls"):
print("Tool Calls:")
for tc in current_step["tool_calls"]:
print(f" - {tc['name']}: {tc['arguments']}")
if current_step.get("action_output"):
print(f"Action Output: {current_step['action_output']}")
# After the stream is finished, process the last state
logger.info("Stream finished.")
if last_state:
# LangGraph streams dicts where keys are node names, values are state dicts
node_name = list(last_state.keys())[0]
actual_state = last_state.get(node_name)
if actual_state:
final_answer_text = actual_state.get("answer_text")
error = actual_state.get("error")
logger.info(
f"Final answer text extracted from last state: {final_answer_text}"
)
logger.info(f"Error extracted from last state: {error}")
# Ensure steps list is consistent with the final state if needed
last_step_in_state = actual_state.get("current_step")
if last_step_in_state and (
not steps
or steps[-1]["step_number"] != last_step_in_state["step_number"]
):
logger.debug("Adding last step from final state to steps list.")
steps.append(last_step_in_state)
else:
logger.warning(
"Could not find actual state dictionary within last_state."
)
return {"steps": steps, "final_answer": final_answer_text, "error": error}
except Exception as e:
import traceback
logger.error(
f"Exception during run_with_streaming: {str(e)}\n{traceback.format_exc()}"
)
# Attempt to return based on the last known state even if exception occurred outside stream
final_answer_text = None
error_msg = str(e)
if last_state:
node_name = list(last_state.keys())[0]
actual_state = last_state.get(node_name)
if actual_state:
final_answer_text = actual_state.get("answer_text")
return {"steps": steps, "final_answer": final_answer_text, "error": error_msg}
def main(task: str, thread_id: Optional[str] = None):
    # A default of str(uuid.uuid4()) is evaluated once at import time and
    # would be shared by every call, so generate a fresh thread ID per run.
    thread_id = thread_id or str(uuid.uuid4())
    # Enhance the question with minimal instructions
enhanced_question = f"""
GAIA Question: {task}
Please solve this multi-step reasoning problem by:
1. Breaking it down into logical steps
2. Using specialized agents when needed
3. Providing the final answer in the exact format requested
"""
logger.info(
f"Starting agent run from __main__ for task: '{task}' with thread_id: {thread_id}"
)
result = asyncio.run(run_with_streaming(enhanced_question, thread_id))
logger.info("Agent run finished.")
logger.info(f"Result: {result}")
return extract_final_answer(result)
if __name__ == "__main__":
# Example Usage
task_to_run = "What is the capital of France?"
thread_id = str(uuid.uuid4()) # Generate a unique thread ID for this run
final_answer = main(task_to_run, thread_id)
print(f"Final Answer: {final_answer}")