wt002 commited on
Commit
9af2eae
·
verified ·
1 Parent(s): 03aebad

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +50 -45
agent.py CHANGED
@@ -1,7 +1,7 @@
1
  # agent.py
2
  import os
3
  from dotenv import load_dotenv
4
- from typing import TypedDict, Annotated, Sequence
5
  from langchain_core.messages import BaseMessage, HumanMessage
6
  from langchain_core.tools import tool
7
  from langchain_openai import ChatOpenAI
@@ -9,8 +9,7 @@ from langgraph.graph import END, StateGraph
9
  from langgraph.prebuilt import ToolNode
10
  from langchain_community.tools import DuckDuckGoSearchResults
11
  from langchain_community.utilities import WikipediaAPIWrapper
12
- from langchain.agents import create_tool_calling_agent
13
- from langchain.agents import AgentExecutor
14
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15
  import operator
16
 
@@ -37,43 +36,53 @@ def calculate(expression: str) -> str:
37
  python_repl = PythonREPL()
38
  return python_repl.run(expression)
39
 
40
- def build_agent(tools: list, llm: ChatOpenAI) -> AgentExecutor:
41
- """Build the agent executor"""
42
- prompt = ChatPromptTemplate.from_messages([
43
- ("system", "You are a helpful AI assistant. Use tools when needed."),
44
- MessagesPlaceholder(variable_name="messages"),
45
- MessagesPlaceholder(variable_name="agent_scratchpad"),
46
- ])
47
- agent = create_tool_calling_agent(llm, tools, prompt)
48
- return AgentExecutor(agent=agent, tools=tools, verbose=True)
49
-
50
- def build_graph(tools: list, agent: AgentExecutor) -> StateGraph:
51
- """Build the LangGraph workflow"""
52
- workflow = StateGraph(AgentState)
53
 
54
- # Define nodes
55
- workflow.add_node("agent", lambda state: {"messages": [agent.invoke(state)["output"]]})
56
- workflow.add_node("tools", ToolNode(tools))
 
 
 
 
 
 
57
 
58
- # Define edges
59
- workflow.set_entry_point("agent")
60
- workflow.add_conditional_edges(
61
- "agent",
62
- lambda state: "continue" if state["messages"][-1].additional_kwargs.get("tool_calls") else "end",
63
- {"continue": "tools", "end": END}
64
- )
65
- workflow.add_edge("tools", "agent")
 
 
 
 
 
 
 
 
 
 
66
 
67
- return workflow.compile()
68
-
69
- class AdvancedAIAgent:
70
- def __init__(self, model_name: str = "gpt-4-turbo"):
71
- self.tools = [wikipedia_search, web_search, calculate]
72
- self.llm = ChatOpenAI(model=model_name, temperature=0.7)
73
- self.agent = build_agent(self.tools, self.llm)
74
- self.workflow = build_graph(self.tools, self.agent)
75
 
76
- def __call__(self, query: str) -> dict:
 
 
 
 
 
77
  """Process a user query"""
78
  state = AgentState(messages=[HumanMessage(content=query)], sender="user")
79
 
@@ -89,23 +98,19 @@ class AdvancedAIAgent:
89
  }
90
  return {"response": "No response generated", "sources": [], "steps": []}
91
 
92
- def _extract_sources(self, messages: Sequence[BaseMessage]) -> list:
 
93
  return [
94
  f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
95
  for msg in messages
96
  if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
97
  ]
98
 
99
- def _extract_steps(self, messages: Sequence[BaseMessage]) -> list:
 
100
  steps = []
101
  for msg in messages:
102
  if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
103
  for call in msg.additional_kwargs['tool_calls']:
104
  steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
105
- return steps
106
-
107
- # Example usage
108
- if __name__ == "__main__":
109
- agent = AdvancedAIAgent()
110
- response = agent("What's the capital of France?")
111
- print(response)
 
1
  # agent.py
2
  import os
3
  from dotenv import load_dotenv
4
+ from typing import TypedDict, Annotated, Sequence, Dict, Any
5
  from langchain_core.messages import BaseMessage, HumanMessage
6
  from langchain_core.tools import tool
7
  from langchain_openai import ChatOpenAI
 
9
  from langgraph.prebuilt import ToolNode
10
  from langchain_community.tools import DuckDuckGoSearchResults
11
  from langchain_community.utilities import WikipediaAPIWrapper
12
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
 
13
  from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
14
  import operator
15
 
 
36
  python_repl = PythonREPL()
37
  return python_repl.run(expression)
38
 
39
+ class AIAgent:
40
+ def __init__(self, model_name: str = "gpt-3.5-turbo"):
41
+ self.tools = [wikipedia_search, web_search, calculate]
42
+ self.llm = ChatOpenAI(model=model_name, temperature=0.7)
43
+ self.agent_executor = self._build_agent_executor()
44
+ self.workflow = self._build_workflow() # Changed from 'graph' to 'workflow'
 
 
 
 
 
 
 
45
 
46
+ def _build_agent_executor(self) -> AgentExecutor:
47
+ """Build the agent executor"""
48
+ prompt = ChatPromptTemplate.from_messages([
49
+ ("system", "You are a helpful AI assistant. Use tools when needed."),
50
+ MessagesPlaceholder(variable_name="messages"),
51
+ MessagesPlaceholder(variable_name="agent_scratchpad"),
52
+ ])
53
+ agent = create_tool_calling_agent(self.llm, self.tools, prompt)
54
+ return AgentExecutor(agent=agent, tools=self.tools, verbose=True)
55
 
56
+ def _build_workflow(self) -> StateGraph:
57
+ """Build and return the compiled workflow"""
58
+ workflow = StateGraph(AgentState)
59
+
60
+ # Define nodes
61
+ workflow.add_node("agent", self._run_agent)
62
+ workflow.add_node("tools", ToolNode(self.tools))
63
+
64
+ # Define edges
65
+ workflow.set_entry_point("agent")
66
+ workflow.add_conditional_edges(
67
+ "agent",
68
+ self._should_continue,
69
+ {"continue": "tools", "end": END}
70
+ )
71
+ workflow.add_edge("tools", "agent")
72
+
73
+ return workflow.compile()
74
 
75
+ def _run_agent(self, state: AgentState) -> Dict[str, Any]:
76
+ """Execute the agent"""
77
+ response = self.agent_executor.invoke({"messages": state["messages"]})
78
+ return {"messages": [response["output"]]}
 
 
 
 
79
 
80
+ def _should_continue(self, state: AgentState) -> str:
81
+ """Determine if the workflow should continue"""
82
+ last_message = state["messages"][-1]
83
+ return "continue" if last_message.additional_kwargs.get("tool_calls") else "end"
84
+
85
+ def __call__(self, query: str) -> Dict[str, Any]:
86
  """Process a user query"""
87
  state = AgentState(messages=[HumanMessage(content=query)], sender="user")
88
 
 
98
  }
99
  return {"response": "No response generated", "sources": [], "steps": []}
100
 
101
+ def _extract_sources(self, messages: Sequence[BaseMessage]) -> List[str]:
102
+ """Extract sources from tool messages"""
103
  return [
104
  f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
105
  for msg in messages
106
  if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
107
  ]
108
 
109
+ def _extract_steps(self, messages: Sequence[BaseMessage]) -> List[str]:
110
+ """Extract reasoning steps"""
111
  steps = []
112
  for msg in messages:
113
  if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
114
  for call in msg.additional_kwargs['tool_calls']:
115
  steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
116
+ return steps