File size: 2,680 Bytes
81d00fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging
from typing import TypedDict
from langgraph.graph import StateGraph, END
from smolagents import ToolCallingAgent, LiteLLMModel
from tools import tools
import yaml
import importlib.resources

# Module logger, plus a basic root configuration at INFO level
# (logging.basicConfig is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Define the state for our agent graph
class AgentState(TypedDict):
    messages: list
    question: str
    answer: str | None

class AgentNode:
    """LangGraph node that answers ``state["question"]`` with a smolagents agent.

    A single ``ToolCallingAgent`` is built when the node is constructed and is
    reused for every invocation of the node.
    """

    def __init__(
        self,
        *,
        model_id: str = "ollama/codellama",
        system_prompt: str | None = None,
    ):
        """Build the LiteLLM model and the tool-calling agent.

        Args:
            model_id: LiteLLM model identifier. Defaults to the Ollama
                ``codellama`` model this node was originally written for.
            system_prompt: Optional replacement for smolagents' default
                tool-calling system prompt. ``None`` keeps the default.
        """
        prompt_templates = self._load_prompt_templates(system_prompt)

        # Initialize the model and agent.
        # LiteLLMModel selects the model through ``model_id``. A ``model=``
        # keyword is only forwarded as an extra completion kwarg and leaves
        # ``model_id`` at the library default, so the requested model would
        # not be the one actually called.
        self.model = LiteLLMModel(
            model_id=model_id,
            temperature=0.0,
            max_tokens=4096,
            top_p=0.9,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=["Observation:"],
        )

        self.agent = ToolCallingAgent(
            model=self.model,
            prompt_templates=prompt_templates,
            tools=tools,
        )

    def _load_prompt_templates(self, system_prompt: str | None) -> dict:
        """Load smolagents' default tool-calling prompt templates.

        The default system prompt is always logged. When ``system_prompt`` is
        given, it replaces ``prompt_templates["system_prompt"]`` and is logged
        as well.
        """
        # Decode explicitly as UTF-8: without an encoding argument,
        # ``read_text()`` on a file-system resource may fall back to the
        # platform's locale encoding (e.g. cp1252 on Windows).
        template_text = (
            importlib.resources.files("smolagents.prompts")
            .joinpath("toolcalling_agent.yaml")
            .read_text(encoding="utf-8")
        )
        prompt_templates = yaml.safe_load(template_text)

        self._log_prompt("Default system prompt:", prompt_templates["system_prompt"])

        if system_prompt is not None:
            prompt_templates["system_prompt"] = system_prompt
            self._log_prompt("Custom system prompt:", system_prompt)

        return prompt_templates

    @staticmethod
    def _log_prompt(title: str, prompt: str) -> None:
        """Log *prompt* at INFO level under *title*, framed by separator rules."""
        rule = "-" * 80
        logger.info(title)
        logger.info(rule)
        logger.info(prompt)
        logger.info(rule)

    def __call__(self, state: AgentState) -> AgentState:
        """Run the agent on ``state["question"]`` and record the result.

        The given state dict is updated in place (``state["answer"]``) and
        returned. Any exception, including a missing ``"question"`` key, is
        logged with its traceback and stored as an ``"Error: ..."`` answer
        instead of propagating.
        """
        try:
            result = self.agent.run(state["question"])
        except Exception as e:
            # Deliberate boundary catch: surface the failure as the answer.
            logger.exception("Error in agent node: %s", e)
            state["answer"] = f"Error: {str(e)}"
        else:
            state["answer"] = result
        return state

def build_agent_graph():
    """Assemble and compile the single-node agent workflow.

    Topology: entry point -> "agent" (an ``AgentNode``) -> END.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    node_name = "agent"

    workflow = StateGraph(AgentState)
    workflow.add_node(node_name, AgentNode())

    # Wire the only node as both the entry point and the last step.
    workflow.set_entry_point(node_name)
    workflow.add_edge(node_name, END)

    return workflow.compile()

# Compiled graph exposed at module level. This runs at import time, so merely
# importing this module constructs an AgentNode (model, agent, and the
# system-prompt logging done in its constructor).
agent_graph = build_agent_graph()