Spaces:
Build error
Build error
Refactor agent.py and graph.py to enhance agent functionality and logging. Introduce Configuration class for managing parameters, improve state handling in AgentRunner, and update agent graph to support step logging and user interaction. Add new tests for agent capabilities and update requirements for code formatting tools.
401799d
unverified
"""Define the agent graph and its components.""" | |
import logging | |
import os | |
import uuid | |
from typing import Dict, List, Optional, TypedDict, Union, cast | |
import yaml | |
from langchain_core.language_models import BaseChatModel | |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnableConfig | |
from langgraph.graph import END, StateGraph | |
from langgraph.prebuilt import ToolExecutor, ToolNode | |
from langgraph.types import interrupt | |
from smolagents import CodeAgent, LiteLLMModel, ToolCallingAgent | |
from configuration import Configuration | |
from tools import tools | |
# Configure logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# Enable LiteLLM debug logging only if environment variable is set | |
import litellm | |
if os.getenv("LITELLM_DEBUG", "false").lower() == "true": | |
litellm.set_verbose = True | |
logger.setLevel(logging.DEBUG) | |
else: | |
litellm.set_verbose = False | |
logger.setLevel(logging.INFO) | |
# Configure LiteLLM to drop unsupported parameters | |
litellm.drop_params = True | |
# Load default prompt templates from local file | |
current_dir = os.path.dirname(os.path.abspath(__file__)) | |
prompts_dir = os.path.join(current_dir, "prompts") | |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml") | |
with open(yaml_path, "r") as f: | |
prompt_templates = yaml.safe_load(f) | |
# Initialize the model and agent using configuration | |
config = Configuration() | |
model = LiteLLMModel( | |
api_base=config.api_base, | |
api_key=config.api_key, | |
model_id=config.model_id, | |
) | |
agent = CodeAgent( | |
add_base_tools=True, | |
max_steps=1, # Execute one step at a time | |
model=model, | |
prompt_templates=prompt_templates, | |
tools=tools, | |
verbosity_level=logging.DEBUG, | |
) | |
class AgentState(TypedDict): | |
"""State for the agent graph.""" | |
messages: List[Union[HumanMessage, AIMessage, SystemMessage]] | |
question: str | |
answer: Optional[str] | |
step_logs: List[Dict] | |
is_complete: bool | |
step_count: int | |
class AgentNode: | |
"""Node that runs the agent.""" | |
def __init__(self, agent: CodeAgent): | |
"""Initialize the agent node with an agent.""" | |
self.agent = agent | |
def __call__( | |
self, state: AgentState, config: Optional[RunnableConfig] = None | |
) -> AgentState: | |
"""Run the agent on the current state.""" | |
# Log current state | |
logger.info("Current state before processing:") | |
logger.info(f"Messages: {state['messages']}") | |
logger.info(f"Question: {state['question']}") | |
logger.info(f"Answer: {state['answer']}") | |
# Get configuration | |
cfg = Configuration.from_runnable_config(config) | |
logger.info(f"Using configuration: {cfg}") | |
# Log execution start | |
logger.info("Starting agent execution") | |
# Run the agent | |
result = self.agent.run(state["question"]) | |
# Log result | |
logger.info(f"Agent execution result type: {type(result)}") | |
logger.info(f"Agent execution result value: {result}") | |
# Update state | |
new_state = state.copy() | |
new_state["messages"].append(AIMessage(content=result)) | |
new_state["answer"] = result | |
new_state["step_count"] += 1 | |
# Log updated state | |
logger.info("Updated state after processing:") | |
logger.info(f"Messages: {new_state['messages']}") | |
logger.info(f"Question: {new_state['question']}") | |
logger.info(f"Answer: {new_state['answer']}") | |
return new_state | |
class StepCallbackNode: | |
"""Node that handles step callbacks and user interaction.""" | |
def __call__( | |
self, state: AgentState, config: Optional[RunnableConfig] = None | |
) -> AgentState: | |
"""Handle step callback and user interaction.""" | |
# Get configuration | |
cfg = Configuration.from_runnable_config(config) | |
# Log the step | |
step_log = { | |
"step": state["step_count"], | |
"messages": [msg.content for msg in state["messages"]], | |
"question": state["question"], | |
"answer": state["answer"], | |
} | |
state["step_logs"].append(step_log) | |
try: | |
# Use interrupt for user input | |
user_input = interrupt( | |
"Press 'c' to continue, 'q' to quit, or 'i' for more info: " | |
) | |
if user_input.lower() == "q": | |
state["is_complete"] = True | |
return state | |
elif user_input.lower() == "i": | |
logger.info(f"Current step: {state['step_count']}") | |
logger.info(f"Question: {state['question']}") | |
logger.info(f"Current answer: {state['answer']}") | |
return self(state, config) # Recursively call for new input | |
elif user_input.lower() == "c": | |
return state | |
else: | |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.") | |
return self(state, config) # Recursively call for new input | |
except Exception as e: | |
logger.warning(f"Error during interrupt: {str(e)}") | |
return state | |
def build_agent_graph(agent: AgentNode) -> StateGraph: | |
"""Build the agent graph.""" | |
# Initialize the graph | |
workflow = StateGraph(AgentState) | |
# Add nodes | |
workflow.add_node("agent", agent) | |
workflow.add_node("callback", StepCallbackNode()) | |
# Add edges | |
workflow.add_edge("agent", "callback") | |
workflow.add_conditional_edges( | |
"callback", | |
lambda x: END if x["is_complete"] else "agent", | |
{True: END, False: "agent"}, | |
) | |
# Set entry point | |
workflow.set_entry_point("agent") | |
return workflow.compile() | |
# Initialize the agent graph | |
agent_graph = build_agent_graph(AgentNode(agent)) | |