"""Define the agent graph and its components."""

import logging
import os
import uuid
from typing import Dict, List, Optional, TypedDict, Union, cast

import yaml
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolNode
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel, ToolCallingAgent

from configuration import Configuration
from tools import tools

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable LiteLLM debug logging only if environment variable is set.
# NOTE: litellm is imported here, after logging is configured, so its
# import-time logging setup does not override ours.
import litellm

if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
    litellm.set_verbose = True
    logger.setLevel(logging.DEBUG)
else:
    litellm.set_verbose = False
    logger.setLevel(logging.INFO)

# Configure LiteLLM to silently drop request parameters the target
# provider does not support, instead of raising an error.
litellm.drop_params = True

# Load default prompt templates from a YAML file shipped next to this module
# (prompts/code_agent.yaml).
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

# Fix: pass an explicit encoding — the previous implicit default is
# platform-dependent and can mis-decode the YAML on non-UTF-8 locales.
with open(yaml_path, "r", encoding="utf-8") as f:
    prompt_templates = yaml.safe_load(f)

# Initialize the model and agent using configuration
config = Configuration()
# LiteLLM-backed chat model; endpoint, key, and model id all come from the
# project Configuration object.
model = LiteLLMModel(
    api_base=config.api_base,
    api_key=config.api_key,
    model_id=config.model_id,
)

# Module-level smolagents CodeAgent shared by the graph below.
agent = CodeAgent(
    add_base_tools=True,
    max_steps=1,  # Execute one step at a time
    model=model,
    prompt_templates=prompt_templates,
    tools=tools,
    verbosity_level=logging.DEBUG,
)


class AgentState(TypedDict):
    """State for the agent graph."""

    # Conversation history accumulated across graph steps.
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    # The user question the agent is asked to answer.
    question: str
    # Latest agent answer; None until the agent node has run.
    answer: Optional[str]
    # One log dict per completed step (see StepCallbackNode).
    step_logs: List[Dict]
    # When True, the conditional edge routes to END.
    is_complete: bool
    # Number of agent-node executions so far.
    step_count: int


class AgentNode:
    """Node that runs the agent."""

    def __init__(self, agent: CodeAgent):
        """Initialize the agent node with an agent."""
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent on the current state.

        Args:
            state: Current graph state; ``state["question"]`` is fed to the agent.
            config: Optional runnable config carrying a ``Configuration``.

        Returns:
            A new state with the agent's answer appended to ``messages``,
            stored in ``answer``, and ``step_count`` incremented.
        """
        # Lazy %-style args avoid formatting cost when the level is disabled.
        logger.info("Current state before processing:")
        logger.info("Messages: %s", state["messages"])
        logger.info("Question: %s", state["question"])
        logger.info("Answer: %s", state["answer"])

        # Get configuration
        cfg = Configuration.from_runnable_config(config)
        logger.info("Using configuration: %s", cfg)

        logger.info("Starting agent execution")
        result = self.agent.run(state["question"])

        logger.info("Agent execution result type: %s", type(result))
        logger.info("Agent execution result value: %s", result)

        # Build the updated state without mutating the input. Fix: dict.copy()
        # is shallow, so the previous list .append() also mutated the caller's
        # state["messages"]; concatenating creates a fresh list instead.
        new_state = state.copy()
        # str() guards against agent.run returning a non-string result, which
        # AIMessage would reject as content.
        new_state["messages"] = state["messages"] + [AIMessage(content=str(result))]
        new_state["answer"] = result
        new_state["step_count"] = state["step_count"] + 1

        logger.info("Updated state after processing:")
        logger.info("Messages: %s", new_state["messages"])
        logger.info("Question: %s", new_state["question"])
        logger.info("Answer: %s", new_state["answer"])

        return new_state


class StepCallbackNode:
    """Node that handles step callbacks and user interaction."""

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Handle step callback and user interaction.

        Records one log entry for the current step, then prompts the user via
        ``interrupt`` until a valid command arrives: 'c' continues, 'q' marks
        the run complete, 'i' prints step info and re-prompts.

        Args:
            state: Current graph state; mutated in place (``step_logs`` and,
                on 'q', ``is_complete``).
            config: Optional runnable config carrying a ``Configuration``.

        Returns:
            The (possibly updated) state.
        """
        # Get configuration
        cfg = Configuration.from_runnable_config(config)

        # Record the step exactly once per node invocation. Fix: the previous
        # implementation recursed into __call__ on 'i' or invalid input, which
        # appended a duplicate step-log entry on every retry and could grow
        # the call stack without bound.
        step_log = {
            "step": state["step_count"],
            "messages": [msg.content for msg in state["messages"]],
            "question": state["question"],
            "answer": state["answer"],
        }
        state["step_logs"].append(step_log)

        while True:
            try:
                # Use interrupt for user input
                user_input = interrupt(
                    "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
                )
            except Exception as e:
                # NOTE(review): langgraph's interrupt() pauses the graph by
                # raising a control-flow exception; this broad except may
                # swallow it — confirm against the installed langgraph version.
                logger.warning(f"Error during interrupt: {str(e)}")
                return state

            command = user_input.lower()
            if command == "q":
                state["is_complete"] = True
                return state
            if command == "c":
                return state
            if command == "i":
                logger.info(f"Current step: {state['step_count']}")
                logger.info(f"Question: {state['question']}")
                logger.info(f"Current answer: {state['answer']}")
                continue  # re-prompt without re-logging the step
            logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")


def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build and compile the agent graph.

    The graph alternates agent -> callback until the callback marks the
    state complete, at which point it routes to END.

    Args:
        agent: The ``AgentNode`` instance executed at each step.

    Returns:
        The compiled graph. (The annotation says ``StateGraph`` for backward
        compatibility, but ``compile()`` returns a compiled-graph object.)
    """
    # Initialize the graph
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode())

    # Add edges
    workflow.add_edge("agent", "callback")
    # Fix: the condition function must return a key of the path map. The
    # previous lambda returned END/"agent" while the map was keyed by
    # True/False, so lookups in the path map could never match.
    workflow.add_conditional_edges(
        "callback",
        lambda x: x["is_complete"],
        {True: END, False: "agent"},
    )

    # Set entry point
    workflow.set_entry_point("agent")

    return workflow.compile()


# Initialize the module-level agent graph consumed by importers of this module.
agent_graph = build_agent_graph(AgentNode(agent))