Spaces:
Restarting
Restarting
Update agent.py
Browse files
agent.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# agent.py
|
2 |
import os
|
3 |
from dotenv import load_dotenv
|
4 |
-
from typing import TypedDict, Annotated, Sequence
|
5 |
from langchain_core.messages import BaseMessage, HumanMessage
|
6 |
from langchain_core.tools import tool
|
7 |
from langchain_openai import ChatOpenAI
|
@@ -9,8 +9,7 @@ from langgraph.graph import END, StateGraph
|
|
9 |
from langgraph.prebuilt import ToolNode
|
10 |
from langchain_community.tools import DuckDuckGoSearchResults
|
11 |
from langchain_community.utilities import WikipediaAPIWrapper
|
12 |
-
from langchain.agents import create_tool_calling_agent
|
13 |
-
from langchain.agents import AgentExecutor
|
14 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
15 |
import operator
|
16 |
|
@@ -37,43 +36,53 @@ def calculate(expression: str) -> str:
|
|
37 |
python_repl = PythonREPL()
|
38 |
return python_repl.run(expression)
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
])
|
47 |
-
agent = create_tool_calling_agent(llm, tools, prompt)
|
48 |
-
return AgentExecutor(agent=agent, tools=tools, verbose=True)
|
49 |
-
|
50 |
-
def build_graph(tools: list, agent: AgentExecutor) -> StateGraph:
|
51 |
-
"""Build the LangGraph workflow"""
|
52 |
-
workflow = StateGraph(AgentState)
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
self.tools = [wikipedia_search, web_search, calculate]
|
72 |
-
self.llm = ChatOpenAI(model=model_name, temperature=0.7)
|
73 |
-
self.agent = build_agent(self.tools, self.llm)
|
74 |
-
self.workflow = build_graph(self.tools, self.agent)
|
75 |
|
76 |
-
def
|
|
|
|
|
|
|
|
|
|
|
77 |
"""Process a user query"""
|
78 |
state = AgentState(messages=[HumanMessage(content=query)], sender="user")
|
79 |
|
@@ -89,23 +98,19 @@ class AdvancedAIAgent:
|
|
89 |
}
|
90 |
return {"response": "No response generated", "sources": [], "steps": []}
|
91 |
|
92 |
-
def _extract_sources(self, messages: Sequence[BaseMessage]) ->
|
|
|
93 |
return [
|
94 |
f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
|
95 |
for msg in messages
|
96 |
if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
|
97 |
]
|
98 |
|
99 |
-
def _extract_steps(self, messages: Sequence[BaseMessage]) ->
|
|
|
100 |
steps = []
|
101 |
for msg in messages:
|
102 |
if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
|
103 |
for call in msg.additional_kwargs['tool_calls']:
|
104 |
steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
|
105 |
-
return steps
|
106 |
-
|
107 |
-
# Example usage
|
108 |
-
if __name__ == "__main__":
|
109 |
-
agent = AdvancedAIAgent()
|
110 |
-
response = agent("What's the capital of France?")
|
111 |
-
print(response)
|
|
|
1 |
# agent.py
|
2 |
import os
|
3 |
from dotenv import load_dotenv
|
4 |
+
from typing import TypedDict, Annotated, Sequence, Dict, Any
|
5 |
from langchain_core.messages import BaseMessage, HumanMessage
|
6 |
from langchain_core.tools import tool
|
7 |
from langchain_openai import ChatOpenAI
|
|
|
9 |
from langgraph.prebuilt import ToolNode
|
10 |
from langchain_community.tools import DuckDuckGoSearchResults
|
11 |
from langchain_community.utilities import WikipediaAPIWrapper
|
12 |
+
from langchain.agents import create_tool_calling_agent, AgentExecutor
|
|
|
13 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
14 |
import operator
|
15 |
|
|
|
36 |
python_repl = PythonREPL()
|
37 |
return python_repl.run(expression)
|
38 |
|
39 |
+
class AIAgent:
|
40 |
+
def __init__(self, model_name: str = "gpt-3.5-turbo"):
    """Wire up the model, the tool belt, the agent executor and the compiled workflow.

    Args:
        model_name: OpenAI chat model identifier (defaults to gpt-3.5-turbo).
    """
    # Model first, then the tools it may call; both feed the executor below.
    self.llm = ChatOpenAI(model=model_name, temperature=0.7)
    self.tools = [wikipedia_search, web_search, calculate]
    # The executor must exist before the workflow, which routes through it.
    self.agent_executor = self._build_agent_executor()
    self.workflow = self._build_workflow()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
def _build_agent_executor(self) -> AgentExecutor:
    """Assemble a tool-calling agent from the prompt, model and tools, wrapped in an AgentExecutor."""
    # Prompt: fixed system instruction, then the running conversation,
    # then the scratchpad slot the tool-calling agent writes into.
    system_message = ("system", "You are a helpful AI assistant. Use tools when needed.")
    conversation_slot = MessagesPlaceholder(variable_name="messages")
    scratchpad_slot = MessagesPlaceholder(variable_name="agent_scratchpad")
    prompt = ChatPromptTemplate.from_messages(
        [system_message, conversation_slot, scratchpad_slot]
    )
    tool_agent = create_tool_calling_agent(self.llm, self.tools, prompt)
    return AgentExecutor(agent=tool_agent, tools=self.tools, verbose=True)
|
55 |
|
56 |
+
def _build_workflow(self) -> Any:
    """Build and compile the LangGraph agent/tools loop.

    Returns:
        The runnable graph produced by ``StateGraph.compile()``. Note the
        previous ``-> StateGraph`` annotation was inaccurate: ``compile()``
        returns a compiled graph object, not the mutable ``StateGraph``
        builder, so callers must not expect builder methods on the result.
    """
    workflow = StateGraph(AgentState)

    # Nodes: the agent step itself, and a ToolNode that executes any
    # tool calls the agent emits.
    workflow.add_node("agent", self._run_agent)
    workflow.add_node("tools", ToolNode(self.tools))

    # Edges: start at the agent, loop agent -> tools -> agent until
    # _should_continue reports no further tool calls ("end").
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        self._should_continue,
        {"continue": "tools", "end": END},
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()
|
74 |
|
75 |
+
def _run_agent(self, state: AgentState) -> Dict[str, Any]:
    """Run one agent step and return the state delta for the messages channel.

    Bug fix: ``response["output"]`` from AgentExecutor is a plain string,
    but the messages channel holds message objects — downstream code
    (``_should_continue``, the ``_extract_*`` helpers) reads
    ``.additional_kwargs`` on the last message and would raise
    ``AttributeError`` on a raw ``str``. Wrap the output in an
    ``AIMessage`` before appending it to the state.
    """
    # Function-scope import keeps the fix self-contained; the module
    # already depends on langchain_core.
    from langchain_core.messages import AIMessage

    response = self.agent_executor.invoke({"messages": state["messages"]})
    return {"messages": [AIMessage(content=response["output"])]}
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
def _should_continue(self, state: AgentState) -> str:
    """Route the workflow: "continue" while the last message requests tool calls, else "end"."""
    last_message = state["messages"][-1]
    tool_calls = last_message.additional_kwargs.get("tool_calls")
    if tool_calls:
        return "continue"
    return "end"
|
84 |
+
|
85 |
+
def __call__(self, query: str) -> Dict[str, Any]:
|
86 |
"""Process a user query"""
|
87 |
state = AgentState(messages=[HumanMessage(content=query)], sender="user")
|
88 |
|
|
|
98 |
}
|
99 |
return {"response": "No response generated", "sources": [], "steps": []}
|
100 |
|
101 |
+
def _extract_sources(self, messages: "Sequence[BaseMessage]") -> "list[str]":
    """Extract "tool: content" source strings from tool-produced messages.

    Bug fix: the previous ``List[str]`` annotation referenced ``List``,
    which does not appear in this module's visible typing imports and
    would raise ``NameError`` at class-definition time. Lazy (string)
    annotations sidestep the problem without changing runtime behavior.

    Only messages carrying a ``name`` key in ``additional_kwargs`` are
    treated as tool output; the tool name falls back to 'unknown'.
    """
    return [
        f"{msg.additional_kwargs.get('name', 'unknown')}: {msg.content}"
        for msg in messages
        if hasattr(msg, 'additional_kwargs') and 'name' in msg.additional_kwargs
    ]
|
108 |
|
109 |
+
def _extract_steps(self, messages: "Sequence[BaseMessage]") -> "list[str]":
    """Extract "Used <tool>: <args>" reasoning steps from tool-call messages.

    Bug fix: the previous ``List[str]`` annotation referenced ``List``,
    which does not appear in this module's visible typing imports and
    would raise ``NameError`` at class-definition time. Lazy (string)
    annotations sidestep the problem without changing runtime behavior.

    Only messages whose ``additional_kwargs`` contains ``tool_calls`` are
    inspected; each call contributes one step string.
    """
    steps = []
    for msg in messages:
        if hasattr(msg, 'additional_kwargs') and 'tool_calls' in msg.additional_kwargs:
            for call in msg.additional_kwargs['tool_calls']:
                steps.append(f"Used {call['function']['name']}: {call['function']['arguments']}")
    return steps
|
|
|
|
|
|
|
|
|
|
|
|