File size: 6,076 Bytes
1dfef0f
03aebad
 
db22cb0
037cb93
 
 
 
eb258dd
5811bc8
037cb93
9af2eae
037cb93
 
 
03aebad
 
037cb93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dfef0f
 
 
037cb93
db22cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9af2eae
 
 
 
 
db22cb0
 
9af2eae
 
 
 
 
 
 
 
 
1dfef0f
9af2eae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dfef0f
9af2eae
 
 
 
1dfef0f
9af2eae
 
 
 
 
 
037cb93
 
 
 
 
 
 
 
 
 
 
 
 
1dfef0f
 
9af2eae
 
037cb93
 
 
 
 
1dfef0f
9af2eae
 
037cb93
 
 
 
 
4c122b6
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# agent.py
import operator
import os
from typing import TypedDict, Annotated, Sequence, Dict, Any, List

from dotenv import load_dotenv
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode

load_dotenv()

class AgentState(TypedDict):
    """Shared state passed between LangGraph nodes.

    The ``operator.add`` metadata on ``messages`` marks it as an
    accumulating channel: langgraph concatenates each node's returned
    list onto the existing messages instead of replacing them.
    """
    # Conversation history; appended-to (not replaced) on each node update.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Originator of the latest update (e.g. "user").
    sender: str

@tool
def wikipedia_search(query: str) -> str:
    """Search Wikipedia for information."""
    # Docstring doubles as the tool description shown to the LLM;
    # keep it short and action-oriented.
    wrapper = WikipediaAPIWrapper()
    return wrapper.run(query)

@tool
def web_search(query: str, num_results: int = 3) -> list:
    """Search the web for current information."""
    # num_results defaults to 3 to keep tool output compact for the LLM.
    searcher = DuckDuckGoSearchResults(num_results=num_results)
    return searcher.run(query)

@tool
def calculate(expression: str) -> str:
    """Evaluate mathematical expressions."""
    # NOTE(review): PythonREPL exec()s arbitrary Python, not just math.
    # The expression comes from the LLM, so treat it as untrusted input;
    # consider an ast-based arithmetic evaluator instead.
    from langchain_experimental.utilities import PythonREPL
    return PythonREPL().run(expression)

def build_graph(tools: list, agent_executor: AgentExecutor) -> StateGraph:
    """Build and return the compiled agent/tool workflow graph.

    Args:
        tools: Tool objects exposed through the graph's "tools" node.
        agent_executor: Executor invoked by the "agent" node.

    Returns:
        The compiled graph (note: actually a compiled workflow object,
        despite the ``StateGraph`` annotation kept for compatibility).
    """
    workflow = StateGraph(AgentState)

    def run_agent(state: AgentState) -> Dict[str, Any]:
        """Invoke the agent executor and append its reply to the state."""
        response = agent_executor.invoke({"messages": state["messages"]})
        # Bug fix: AgentExecutor returns a plain string under "output".
        # Appending the raw string made should_continue crash with
        # AttributeError on `.additional_kwargs`; wrap it in an AIMessage
        # so every entry in `messages` is a BaseMessage.
        return {"messages": [AIMessage(content=response["output"])]}

    def should_continue(state: AgentState) -> str:
        """Route to the tool node while the last message requests tool calls."""
        last_message = state["messages"][-1]
        return "continue" if last_message.additional_kwargs.get("tool_calls") else "end"

    workflow.add_node("agent", run_agent)
    workflow.add_node("tools", ToolNode(tools))
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "tools", "end": END},
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()

class AIAgent:
    """LLM agent that answers queries with search and calculator tools.

    Wraps a tool-calling ``AgentExecutor`` inside the LangGraph workflow
    produced by the module-level ``build_graph`` and exposes a simple
    callable interface returning the response plus extracted
    sources/steps.

    Note: the previous ``_build_workflow``/``_run_agent``/
    ``_should_continue`` methods were unused duplicates of
    ``build_graph`` (``__init__`` always used ``build_graph``) and have
    been removed as dead code.
    """

    def __init__(self, model_name: str = "gpt-3.5-turbo"):
        self.tools = [wikipedia_search, web_search, calculate]
        self.llm = ChatOpenAI(model=model_name, temperature=0.7)
        self.agent_executor = self._build_agent_executor()
        self.workflow = build_graph(self.tools, self.agent_executor)  # Using the standalone function

    def _build_agent_executor(self) -> AgentExecutor:
        """Build the tool-calling agent executor."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a helpful AI assistant. Use tools when needed."),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])
        agent = create_tool_calling_agent(self.llm, self.tools, prompt)
        return AgentExecutor(agent=agent, tools=self.tools, verbose=True)

    def __call__(self, query: str) -> Dict[str, Any]:
        """Process a user query and return response, sources and steps.

        Bug fix: ``workflow.stream()`` yields ``{node_name: node_output}``
        dicts, so the old ``key == "messages"`` check never matched and
        this method always fell through to "No response generated". We
        now read "messages" from inside each node's output. The graph
        also never mutates the caller-local ``state`` dict, so sources
        and steps are extracted from the locally accumulated message
        list instead of the initial state.
        """
        state = AgentState(messages=[HumanMessage(content=query)], sender="user")
        collected: List[BaseMessage] = list(state["messages"])

        for step in self.workflow.stream(state):
            for node_output in step.values():
                if not isinstance(node_output, dict):
                    continue
                for message in node_output.get("messages", []):
                    if isinstance(message, BaseMessage):
                        collected.append(message)

        # Anything beyond the initial HumanMessage came from the graph;
        # the last such message is the final answer.
        if len(collected) > 1:
            final = collected[-1]
            return {
                "response": final.content,
                "sources": self._extract_sources(collected),
                "steps": self._extract_steps(collected),
            }
        return {"response": "No response generated", "sources": [], "steps": []}

    def _extract_sources(self, messages: Sequence[BaseMessage]) -> List[str]:
        """Extract "name: content" strings from messages tagged with a tool name."""
        return [
            f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
            for msg in messages
            if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
        ]

    def _extract_steps(self, messages: Sequence[BaseMessage]) -> List[str]:
        """Extract human-readable tool-call steps from assistant messages."""
        steps = []
        for msg in messages:
            if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
                for call in msg.additional_kwargs['tool_calls']:
                    steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
        return steps

if __name__ == "__main__":
    # Smoke test: ask one question and pretty-print the structured result.
    agent = AIAgent()
    result = agent("What's the capital of France?")
    print("Response:", result["response"])
    for label, key in (("Sources", "sources"), ("Steps", "steps")):
        items = result[key]
        if items:
            print(f"\n{label}:")
            for item in items:
                print("-", item)