# agent.py
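"""A small research assistant built on LangChain and LangGraph.

Three tools (Wikipedia lookup, DuckDuckGo web search, and a Python-REPL
calculator) are wired into a tool-calling AgentExecutor, which is then wrapped
in a LangGraph state machine that tracks the conversation as a list of messages.
"""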
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.utilities import WikipediaAPIWrapper
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
import operator

class AgentState(TypedDict):
    # operator.add as the reducer tells LangGraph to append message updates
    # returned by nodes rather than overwrite the existing list.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    sender: str

@tool
def wikipedia_search(query: str) -> str:
    """Search Wikipedia for information."""
    return WikipediaAPIWrapper().run(query)

@tool
def web_search(query: str, num_results: int = 3) -> str:
    """Search the web for current information."""
    return DuckDuckGoSearchResults(num_results=num_results).run(query)

@tool
def calculate(expression: str) -> str:
    """Evaluate a mathematical expression (runs unsandboxed Python, so keep inputs trusted)."""
    from langchain_experimental.utilities import PythonREPL
    python_repl = PythonREPL()
    # PythonREPL captures stdout, so print the expression's value to return a result.
    return python_repl.run(f"print({expression})")

def build_agent(tools: list, llm: ChatOpenAI) -> AgentExecutor:
    """Build the agent executor"""
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful AI assistant. Use tools when needed."),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])
    agent = create_tool_calling_agent(llm, tools, prompt)
    return AgentExecutor(agent=agent, tools=tools, verbose=True)
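# Standalone usage sketch (assumes OPENAI_API_KEY is set in the environment):
#   executor = build_agent([wikipedia_search], ChatOpenAI(model="gpt-4-turbo"))
#   result = executor.invoke({"messages": [HumanMessage(content="Who wrote Dune?")]})
#   print(result["output"])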

def build_graph(tools: list, agent: AgentExecutor):
    """Build and compile the LangGraph workflow."""
    workflow = StateGraph(AgentState)

    # Define nodes. AgentExecutor returns {"output": "<final answer>"} as a plain
    # string, so wrap it in an AIMessage before appending it to the state.
    workflow.add_node(
        "agent",
        lambda state: {"messages": [AIMessage(content=agent.invoke(state)["output"])]},
    )
    workflow.add_node("tools", ToolNode(tools))

    # Define edges: go to the tool node only if the last message carries tool
    # calls, otherwise finish. Because the AgentExecutor already runs its tools
    # internally, this normally routes straight to END.
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        lambda state: "continue" if state["messages"][-1].additional_kwargs.get("tool_calls") else "end",
        {"continue": "tools", "end": END},
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()
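# A commonly used alternative (sketch only, not part of this module's design):
# skip AgentExecutor, bind the tools directly to the chat model, and let the
# graph drive the tool loop via ToolNode. Assumes a recent langchain_core where
# AIMessage exposes a .tool_calls attribute.
#
#   def build_native_graph(tools: list, llm: ChatOpenAI):
#       model = llm.bind_tools(tools)
#       graph = StateGraph(AgentState)
#       graph.add_node("agent", lambda state: {"messages": [model.invoke(state["messages"])]})
#       graph.add_node("tools", ToolNode(tools))
#       graph.set_entry_point("agent")
#       graph.add_conditional_edges(
#           "agent",
#           lambda state: "tools" if state["messages"][-1].tool_calls else END,
#       )
#       graph.add_edge("tools", "agent")
#       return graph.compile()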

class AIAgent:
    def __init__(self, model_name: str = "gpt-4-turbo"):
        self.tools = [wikipedia_search, web_search, calculate]
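        # ChatOpenAI reads the OPENAI_API_KEY environment variable; no key is passed explicitly here.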
        self.llm = ChatOpenAI(model=model_name, temperature=0.7)
        self.agent = build_agent(self.tools, self.llm)
        self.workflow = build_graph(self.tools, self.agent)
    
    def __call__(self, query: str) -> dict:
        """Process a user query through the compiled graph."""
        state = AgentState(messages=[HumanMessage(content=query)], sender="user")
        collected: list = list(state["messages"])

        # stream(..., stream_mode="updates") yields {node_name: state_update} for
        # each executed node; collect the new messages from those updates instead
        # of re-reading the (unchanged) input state.
        for output in self.workflow.stream(state, stream_mode="updates"):
            for update in output.values():
                for message in update.get("messages", []):
                    if isinstance(message, BaseMessage):
                        collected.append(message)

        if len(collected) > 1:
            return {
                "response": collected[-1].content,
                "sources": self._extract_sources(collected),
                "steps": self._extract_steps(collected),
            }
        return {"response": "No response generated", "sources": [], "steps": []}
    
    def _extract_sources(self, messages: Sequence[BaseMessage]) -> list:
        # Tool results surface as messages whose .name (or legacy additional_kwargs
        # "name") identifies the tool that produced them.
        return [
            f"{getattr(msg, 'name', None) or msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
            for msg in messages
            if getattr(msg, "name", None) or "name" in msg.additional_kwargs
        ]
    
    def _extract_steps(self, messages: Sequence[BaseMessage]) -> list:
        steps = []
        for msg in messages:
            if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
                for call in msg.additional_kwargs['tool_calls']:
                    steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
        return steps

# Example usage
if __name__ == "__main__":
    agent = AIAgent()
    response = agent("What's the capital of France?")
    print(response)
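    # Expected shape of the result dict (actual contents vary by model and run):
    # {"response": "...", "sources": [...], "steps": [...]}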