wt002 commited on
Commit
1dfef0f
·
verified ·
1 Parent(s): 5811bc8

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +48 -70
agent.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import TypedDict, Annotated, Sequence
2
  from langchain_core.messages import BaseMessage, HumanMessage
3
  from langchain_core.tools import tool
@@ -10,14 +11,11 @@ from langchain.agents import create_tool_calling_agent
10
  from langchain.agents import AgentExecutor
11
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
12
  import operator
13
- import json
14
 
15
- # Define the agent state
16
  class AgentState(TypedDict):
17
  messages: Annotated[Sequence[BaseMessage], operator.add]
18
  sender: str
19
 
20
- # Initialize tools
21
  @tool
22
  def wikipedia_search(query: str) -> str:
23
  """Search Wikipedia for information."""
@@ -31,75 +29,50 @@ def web_search(query: str, num_results: int = 3) -> list:
31
  @tool
32
  def calculate(expression: str) -> str:
33
  """Evaluate mathematical expressions."""
34
- from langchain_experimental.tools import PythonREPLTool
35
- python_repl = PythonREPLTool()
36
- return python_repl.run(f"print({expression})")
37
 
38
- class AdvancedAIAgent:
39
- def __init__(self, model_name="gpt-4-turbo"):
40
- # Initialize tools and LLM
41
- self.tools = [wikipedia_search, web_search, calculate]
42
- self.llm = ChatOpenAI(model=model_name, temperature=0.7)
43
-
44
- # Create the agent
45
- self.agent = self._create_agent()
46
-
47
- # Build the graph workflow
48
- self.workflow = self._build_graph()
49
 
50
- def _create_agent(self) -> AgentExecutor:
51
- """Create the agent with tools and prompt"""
52
- prompt = ChatPromptTemplate.from_messages([
53
- ("system", "You are a helpful AI assistant. Use tools when needed."),
54
- MessagesPlaceholder(variable_name="messages"),
55
- MessagesPlaceholder(variable_name="agent_scratchpad"),
56
- ])
57
-
58
- agent = create_tool_calling_agent(self.llm, self.tools, prompt)
59
- return AgentExecutor(agent=agent, tools=self.tools, verbose=True)
60
-
61
- def _build_graph(self):
62
- """Build the LangGraph workflow"""
63
- workflow = StateGraph(AgentState)
64
-
65
- # Define nodes
66
- workflow.add_node("agent", self._call_agent)
67
- workflow.add_node("tools", ToolNode(self.tools))
68
-
69
- # Define edges
70
- workflow.set_entry_point("agent")
71
- workflow.add_conditional_edges(
72
- "agent",
73
- self._should_continue,
74
- {
75
- "continue": "tools",
76
- "end": END
77
- }
78
- )
79
- workflow.add_edge("tools", "agent")
80
-
81
- return workflow.compile()
82
-
83
- def _call_agent(self, state: AgentState):
84
- """Execute the agent"""
85
- response = self.agent.invoke({"messages": state["messages"]})
86
- return {"messages": [response["output"]]}
87
-
88
- def _should_continue(self, state: AgentState):
89
- """Determine if the workflow should continue"""
90
- last_message = state["messages"][-1]
91
-
92
- # If no tool calls, end
93
- if not last_message.additional_kwargs.get("tool_calls"):
94
- return "end"
95
- return "continue"
96
 
 
 
 
 
 
 
 
97
  def __call__(self, query: str) -> dict:
98
  """Process a user query"""
99
- # Initialize state
100
  state = AgentState(messages=[HumanMessage(content=query)], sender="user")
101
 
102
- # Execute the workflow
103
  for output in self.workflow.stream(state):
104
  for key, value in output.items():
105
  if key == "messages":
@@ -110,20 +83,25 @@ class AdvancedAIAgent:
110
  "sources": self._extract_sources(state["messages"]),
111
  "steps": self._extract_steps(state["messages"])
112
  }
113
-
 
114
  def _extract_sources(self, messages: Sequence[BaseMessage]) -> list:
115
- """Extract sources from tool messages"""
116
  return [
117
  f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
118
  for msg in messages
119
  if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
120
  ]
121
-
122
  def _extract_steps(self, messages: Sequence[BaseMessage]) -> list:
123
- """Extract reasoning steps"""
124
  steps = []
125
  for msg in messages:
126
  if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
127
  for call in msg.additional_kwargs['tool_calls']:
128
  steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
129
- return steps
 
 
 
 
 
 
 
1
+ # agent.py
2
  from typing import TypedDict, Annotated, Sequence
3
  from langchain_core.messages import BaseMessage, HumanMessage
4
  from langchain_core.tools import tool
 
11
  from langchain.agents import AgentExecutor
12
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
13
  import operator
 
14
 
 
15
  class AgentState(TypedDict):
16
  messages: Annotated[Sequence[BaseMessage], operator.add]
17
  sender: str
18
 
 
19
  @tool
20
  def wikipedia_search(query: str) -> str:
21
  """Search Wikipedia for information."""
 
29
  @tool
30
  def calculate(expression: str) -> str:
31
  """Evaluate mathematical expressions."""
32
+ from langchain_experimental.utilities import PythonREPL
33
+ python_repl = PythonREPL()
34
+ return python_repl.run(expression)
35
 
36
+ def build_agent(tools: list, llm: ChatOpenAI) -> AgentExecutor:
37
+ """Build the agent executor"""
38
+ prompt = ChatPromptTemplate.from_messages([
39
+ ("system", "You are a helpful AI assistant. Use tools when needed."),
40
+ MessagesPlaceholder(variable_name="messages"),
41
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
42
+ ])
43
+ agent = create_tool_calling_agent(llm, tools, prompt)
44
+ return AgentExecutor(agent=agent, tools=tools, verbose=True)
 
 
45
 
46
+ def build_graph(tools: list, agent: AgentExecutor) -> StateGraph:
47
+ """Build the LangGraph workflow"""
48
+ workflow = StateGraph(AgentState)
49
+
50
+ # Define nodes
51
+ workflow.add_node("agent", lambda state: {"messages": [agent.invoke(state)["output"]]})
52
+ workflow.add_node("tools", ToolNode(tools))
53
+
54
+ # Define edges
55
+ workflow.set_entry_point("agent")
56
+ workflow.add_conditional_edges(
57
+ "agent",
58
+ lambda state: "continue" if state["messages"][-1].additional_kwargs.get("tool_calls") else "end",
59
+ {"continue": "tools", "end": END}
60
+ )
61
+ workflow.add_edge("tools", "agent")
62
+
63
+ return workflow.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ class AIAgent:
66
+ def __init__(self, model_name: str = "gpt-4-turbo"):
67
+ self.tools = [wikipedia_search, web_search, calculate]
68
+ self.llm = ChatOpenAI(model=model_name, temperature=0.7)
69
+ self.agent = build_agent(self.tools, self.llm)
70
+ self.workflow = build_graph(self.tools, self.agent)
71
+
72
  def __call__(self, query: str) -> dict:
73
  """Process a user query"""
 
74
  state = AgentState(messages=[HumanMessage(content=query)], sender="user")
75
 
 
76
  for output in self.workflow.stream(state):
77
  for key, value in output.items():
78
  if key == "messages":
 
83
  "sources": self._extract_sources(state["messages"]),
84
  "steps": self._extract_steps(state["messages"])
85
  }
86
+ return {"response": "No response generated", "sources": [], "steps": []}
87
+
88
  def _extract_sources(self, messages: Sequence[BaseMessage]) -> list:
 
89
  return [
90
  f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
91
  for msg in messages
92
  if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
93
  ]
94
+
95
  def _extract_steps(self, messages: Sequence[BaseMessage]) -> list:
 
96
  steps = []
97
  for msg in messages:
98
  if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
99
  for call in msg.additional_kwargs['tool_calls']:
100
  steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
101
+ return steps
102
+
103
+ # Example usage
104
+ if __name__ == "__main__":
105
+ agent = AIAgent()
106
+ response = agent("What's the capital of France?")
107
+ print(response)