wt002 committed
Commit a52ceb6 · verified · 1 Parent(s): b49b95b

Update agent.py

Files changed (1):
  agent.py +83 -192

agent.py CHANGED
@@ -1,213 +1,104 @@
  # agent.py
  import os
- from dotenv import load_dotenv
- from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition
- from langgraph.prebuilt import ToolNode
- #from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_groq import ChatGroq
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
- from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_community.document_loaders import WikipediaLoader
- from langchain_community.document_loaders import ArxivLoader
- from langchain_community.vectorstores import SupabaseVectorStore
- from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool
- from langchain.tools.retriever import create_retriever_tool
- from supabase.client import Client, create_client

  load_dotenv()

- @tool
- def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.
-     Args:
-         a: first int
-         b: second int
-     """
-     return a * b

- @tool
- def add(a: int, b: int) -> int:
-     """Add two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a + b

  @tool
- def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a - b

  @tool
- def divide(a: int, b: int) -> int:
-     """Divide two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     if b == 0:
-         raise ValueError("Cannot divide by zero.")
-     return a / b

  @tool
- def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a % b

- @tool
- def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-
-     Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"wiki_results": formatted_search_docs}
-
- @tool
- def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.
-
-     Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
          ])
-     return {"web_results": formatted_search_docs}
-
- @tool
- def arvix_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 results.
-
-     Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"arvix_results": formatted_search_docs}
-
-
- # load the system prompt from the file
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     system_prompt = f.read()
-
- # System message
- sys_msg = SystemMessage(content=system_prompt)
-
- # build a retriever
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
- supabase: Client = create_client(
-     os.environ.get("SUPABASE_URL"),
-     os.environ.get("SUPABASE_SERVICE_KEY"))
- vector_store = SupabaseVectorStore(
-     client=supabase,
-     embedding=embeddings,
-     table_name="documents",
-     query_name="match_documents_langchain",
- )
- create_retriever_tool = create_retriever_tool(
-     retriever=vector_store.as_retriever(),
-     name="Question Search",
-     description="A tool to retrieve similar questions from a vector store.",
- )
-
- tools = [
-     multiply,
-     add,
-     subtract,
-     divide,
-     modulus,
-     wiki_search,
-     web_search,
-     arvix_search,
- ]
-
- # Build graph function
- def build_graph(provider: str = "groq"):
-     """Build the graph"""
-     # Load environment variables from .env file
-     if provider == "google":
-         # Google Gemini
-         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-     elif provider == "groq":
-         # Groq https://console.groq.com/docs/models
-         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional: qwen-qwq-32b gemma2-9b-it
-     elif provider == "huggingface":
-         # TODO: Add huggingface endpoint
-         llm = ChatHuggingFace(
-             llm=HuggingFaceEndpoint(
-                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
-                 temperature=0,
-             ),
          )
-     else:
-         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-     # Bind tools to LLM
-     llm_with_tools = llm.bind_tools(tools)
-
-     # Node
-     def assistant(state: MessagesState):
-         """Assistant node"""
-         return {"messages": [llm_with_tools.invoke(state["messages"])]}

-     def retriever(state: MessagesState):
-         """Retriever node"""
-         similar_question = vector_store.similarity_search(state["messages"][0].content)
-         example_msg = HumanMessage(
-             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-         )
-         return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
-     builder = StateGraph(MessagesState)
-     builder.add_node("retriever", retriever)
-     builder.add_node("assistant", assistant)
-     builder.add_node("tools", ToolNode(tools))
-     builder.add_edge(START, "retriever")
-     builder.add_edge("retriever", "assistant")
-     builder.add_conditional_edges(
-         "assistant",
-         tools_condition,
-     )
-     builder.add_edge("tools", "assistant")
-
-     # Compile graph
-     return builder.compile()
-
- # test

  if __name__ == "__main__":
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     # Build the graph
-     graph = build_graph(provider="groq")
-     # Run the graph
-     messages = [HumanMessage(content=question)]
-     messages = graph.invoke({"messages": messages})
-     for m in messages["messages"]:
-         m.pretty_print()
 
  # agent.py
  import os
+ from typing import TypedDict, Annotated, Sequence, Dict, Any
+ from dotenv import load_dotenv
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
  from langchain_core.tools import tool
+ from langchain_openai import ChatOpenAI
+ from langgraph.graph import END, StateGraph
+ from langgraph.prebuilt import ToolNode
+ from langchain_community.tools import DuckDuckGoSearchResults
+ from langchain_community.utilities import WikipediaAPIWrapper
+ from langchain.agents import create_tool_calling_agent, AgentExecutor
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ import operator
+ from langchain_experimental.utilities import PythonREPL

  load_dotenv()

+ class AgentState(TypedDict):
+     # operator.add makes LangGraph append each node's messages to the history
+     messages: Annotated[Sequence[BaseMessage], operator.add]
+     sender: str

  @tool
+ def wikipedia_search(query: str) -> str:
+     """Search Wikipedia for information."""
+     return WikipediaAPIWrapper().run(query)

  @tool
+ def web_search(query: str, num_results: int = 3) -> str:
+     """Search the web for current information."""
+     return DuckDuckGoSearchResults(num_results=num_results).run(query)

  @tool
+ def calculate(expression: str) -> str:
+     """Evaluate mathematical expressions."""
+     python_repl = PythonREPL()
+     return python_repl.run(expression)

+ class BasicAgent:
+     """A complete LangGraph agent implementation."""
+
+     def __init__(self, model_name: str = "gpt-3.5-turbo"):
+         self.tools = [wikipedia_search, web_search, calculate]
+         self.llm = ChatOpenAI(model=model_name, temperature=0.7)
+         self.agent_executor = self._build_agent_executor()
+         self.workflow = self._build_workflow()  # compile the graph once, up front
+
+     def _build_agent_executor(self) -> AgentExecutor:
+         """Build the agent executor with tools."""
+         prompt = ChatPromptTemplate.from_messages([
+             ("system", "You are a helpful AI assistant. Use tools when needed."),
+             MessagesPlaceholder(variable_name="messages"),
+             MessagesPlaceholder(variable_name="agent_scratchpad"),
          ])
+         agent = create_tool_calling_agent(self.llm, self.tools, prompt)
+         return AgentExecutor(agent=agent, tools=self.tools, verbose=True)
+
+     def _build_workflow(self):
+         """Build and compile the agent workflow."""
+         workflow = StateGraph(AgentState)
+
+         workflow.add_node("agent", self._run_agent)
+         # Safety net: AgentExecutor already runs tools itself, so this node
+         # only fires if the agent's reply still carries tool_calls
+         workflow.add_node("tools", ToolNode(self.tools))
+
+         workflow.set_entry_point("agent")
+         workflow.add_conditional_edges(
+             "agent",
+             self._should_continue,
+             {"continue": "tools", "end": END}
          )
+         workflow.add_edge("tools", "agent")
+
+         return workflow.compile()
+
+     def _run_agent(self, state: AgentState) -> Dict[str, Any]:
+         """Execute the agent."""
+         response = self.agent_executor.invoke({"messages": state["messages"]})
+         # Wrap the executor's string output so the state keeps BaseMessage objects
+         return {"messages": [AIMessage(content=response["output"])]}
+
+     def _should_continue(self, state: AgentState) -> str:
+         """Determine if the workflow should continue."""
+         last_message = state["messages"][-1]
+         return "continue" if last_message.additional_kwargs.get("tool_calls") else "end"
+
+     def __call__(self, question: str) -> str:
+         """Process a user question and return a response."""
+         # Initialize state with the user's question
+         state = AgentState(messages=[HumanMessage(content=question)], sender="user")
+
+         # Execute the workflow; stream() yields {node_name: state_update} dicts
+         for output in self.workflow.stream(state):
+             for node_update in output.values():
+                 for message in node_update.get("messages", []):
+                     if isinstance(message, BaseMessage):
+                         return message.content
+
+         return "Sorry, I couldn't generate a response."
+
+ # Example usage
  if __name__ == "__main__":
+     agent = BasicAgent()
+     response = agent("What's the capital of France?")
+     print(response)
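
A note for anyone running the committed file: `ChatOpenAI` reads its key from the OPENAI_API_KEY environment variable, which `load_dotenv()` can pull from a local `.env`. A minimal smoke test, assuming only that the key is set (the module and class names come from this commit; the question string and the `smoke_test.py` filename are arbitrary examples):

    # smoke_test.py -- minimal sketch; assumes OPENAI_API_KEY is available
    import os
    from dotenv import load_dotenv

    load_dotenv()  # pull OPENAI_API_KEY from .env if present
    assert os.environ.get("OPENAI_API_KEY"), "ChatOpenAI needs OPENAI_API_KEY"

    from agent import BasicAgent

    agent = BasicAgent(model_name="gpt-3.5-turbo")
    print(agent("What is 17 * 23? Use the calculate tool."))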
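
A design note on the graph wiring: because `AgentExecutor` resolves tool calls internally, `_run_agent` returns a finished answer and `_should_continue` normally routes straight to `end`, leaving the `tools` node as a safety net. The previous revision of this file instead bound the tools to the LLM and let the graph loop execute them; below is a sketch of that alternative against the new `AgentState`, using the prebuilt `tools_condition` router (the function name `build_tool_loop` is made up for illustration):

    # Alternative wiring sketch: the graph, not AgentExecutor, runs the tools
    from langgraph.graph import StateGraph
    from langgraph.prebuilt import ToolNode, tools_condition

    from agent import AgentState  # the TypedDict defined in this commit

    def build_tool_loop(llm, tools):
        llm_with_tools = llm.bind_tools(tools)

        def assistant(state: AgentState):
            # The AIMessage carries tool_calls whenever the model requests a tool
            return {"messages": [llm_with_tools.invoke(state["messages"])]}

        builder = StateGraph(AgentState)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.set_entry_point("assistant")
        # tools_condition routes to the "tools" node on tool_calls, else to END
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()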