wt002 committed · verified · Commit 26f5620 · 1 Parent(s): 92047ee

Update agent.py

Files changed (1):
  agent.py +203 -84
agent.py CHANGED
@@ -1,105 +1,224 @@
  # agent.py

  import os
  from dotenv import load_dotenv
- from typing import TypedDict, Annotated, Sequence, Dict, Any, List
- from langchain_core.messages import BaseMessage, HumanMessage
- from langchain_core.tools import tool
- from langchain_openai import ChatOpenAI
- from langgraph.graph import END, StateGraph
  from langgraph.prebuilt import ToolNode
- from langchain_community.tools import DuckDuckGoSearchResults
- from langchain_community.utilities import WikipediaAPIWrapper
- from langchain.agents import create_tool_calling_agent, AgentExecutor
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- import operator
- from langchain_experimental.utilities import PythonREPL

  load_dotenv()

- class AgentState(TypedDict):
-     messages: Annotated[Sequence[BaseMessage], operator.add]
-     sender: str

  @tool
- def wikipedia_search(query: str) -> str:
-     """Search Wikipedia for information."""
-     return WikipediaAPIWrapper().run(query)

  @tool
- def web_search(query: str, num_results: int = 3) -> list:
-     """Search the web for current information."""
-     return DuckDuckGoSearchResults(num_results=num_results).run(query)

  @tool
- def calculate(expression: str) -> str:
-     """Evaluate mathematical expressions."""
-     python_repl = PythonREPL()
-     return python_repl.run(expression)

- class BasicAgent:
-     """A complete langgraph agent implementation."""
-
-     def __init__(self, model_name: str = "gpt-3.5-turbo"):
-         self.tools = [wikipedia_search, web_search, calculate]
-         self.llm = ChatOpenAI(model=model_name, temperature=0.7)
-         self.agent_executor = self._build_agent_executor()
-         self.workflow = self._build_workflow()  # Initialize workflow here
-
-     def _build_agent_executor(self) -> AgentExecutor:
-         """Build the agent executor with tools."""
-         prompt = ChatPromptTemplate.from_messages([
-             ("system", "You are a helpful AI assistant. Use tools when needed."),
-             MessagesPlaceholder(variable_name="messages"),
-             MessagesPlaceholder(variable_name="agent_scratchpad"),
          ])
-         agent = create_tool_calling_agent(self.llm, self.tools, prompt)
-         return AgentExecutor(agent=agent, tools=self.tools, verbose=True)
-
-     def _build_workflow(self) -> StateGraph:
-         """Build and compile the agent workflow."""
-         workflow = StateGraph(AgentState)
-
-         workflow.add_node("agent", self._run_agent)
-         workflow.add_node("tools", ToolNode(self.tools))
-
-         workflow.set_entry_point("agent")
-         workflow.add_conditional_edges(
-             "agent",
-             self._should_continue,
-             {"continue": "tools", "end": END}
          )
-         workflow.add_edge("tools", "agent")
-
-         return workflow.compile()
-
-     def _run_agent(self, state: AgentState) -> Dict[str, Any]:
-         """Execute the agent."""
-         response = self.agent_executor.invoke({"messages": state["messages"]})
-         return {"messages": [response["output"]]}
-
-     def _should_continue(self, state: AgentState) -> str:
-         """Determine if the workflow should continue."""
-         last_message = state["messages"][-1]
-         return "continue" if last_message.additional_kwargs.get("tool_calls") else "end"

-     def __call__(self, question: str) -> str:
-         """Process a user question and return a response."""
-         # Initialize state with the user's question
-         state = AgentState(messages=[HumanMessage(content=question)], sender="user")
-
-         # Execute the workflow
-         for output in self.workflow.stream(state):
-             for key, value in output.items():
-                 if key == "messages":
-                     for message in value:
-                         if isinstance(message, BaseMessage):
-                             return message.content
-
-         return "Sorry, I couldn't generate a response."
-
- # Example usage
- if __name__ == "__main__":
-     agent = BasicAgent()
-     response = agent("What's the capital of France?")
-     print(response)
  # agent.py
+
  import os
  from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition
  from langgraph.prebuilt import ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
+ from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client

  load_dotenv()

+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a * b

+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a + b

  @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a - b

  @tool
+ def divide(a: int, b: int) -> float:
+     """Divide two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b

  @tool
+ def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a % b

+ @tool
+ def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query and return a maximum of 2 results.
+
+     Args:
+         query: The search query."""
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ])
+     return formatted_search_docs
+
+ @tool
+ def web_search(query: str) -> str:
+     """Search Tavily for a query and return a maximum of 3 results.
+
+     Args:
+         query: The search query."""
+     # TavilySearchResults returns a list of dicts with "url" and "content" keys,
+     # not Document objects, so format them accordingly.
+     search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
+             for doc in search_docs
          ])
+     return formatted_search_docs
+
+ @tool
+ def arxiv_search(query: str) -> str:
+     """Search Arxiv for a query and return a maximum of 3 results.
+
+     Args:
+         query: The search query."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     # ArxivLoader metadata typically carries "Title" rather than "source".
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}"/>\n{doc.page_content[:1000]}\n</Document>'
+             for doc in search_docs
+         ])
+     return formatted_search_docs
+
+ # Load the system prompt from file
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+
+ # System message
+ sys_msg = SystemMessage(content=system_prompt)
+
+ # Build a retriever
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"))
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )
+ retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name="question_search",
+     description="A tool to retrieve similar questions from a vector store.",
+ )
+
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arxiv_search,
+ ]
+
+ # Build graph function
+ def build_graph(provider: str = "google"):
+     """Build the graph."""
+     # Select the chat model for the requested provider
+     if provider == "google":
+         # Google Gemini
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+     elif provider == "groq":
+         # Groq https://console.groq.com/docs/models
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # alternatives: qwen-qwq-32b, gemma2-9b-it
+     elif provider == "huggingface":
+         # Hugging Face Inference API endpoint
+         llm = ChatHuggingFace(
+             llm=HuggingFaceEndpoint(
+                 endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0,
+             ),
          )
+     else:
+         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+     # Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)
+
+     # Node
+     def assistant(state: MessagesState):
+         """Assistant node"""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}

+     # Earlier retriever node, kept for reference:
+     # def retriever(state: MessagesState):
+     #     """Retriever node"""
+     #     similar_question = vector_store.similarity_search(state["messages"][0].content)
+     #     example_msg = HumanMessage(
+     #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
+     #     )
+     #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+
+     def retriever(state: MessagesState):
+         """Retriever node: answer directly from the closest stored question."""
+         query = state["messages"][-1].content
+         similar_doc = vector_store.similarity_search(query, k=1)[0]
+
+         content = similar_doc.page_content
+         if "Final answer :" in content:
+             answer = content.split("Final answer :")[-1].strip()
+         else:
+             answer = content.strip()
+
+         return {"messages": [AIMessage(content=answer)]}
+
+     # Full tool-calling agent graph, kept for reference:
+     # builder = StateGraph(MessagesState)
+     # builder.add_node("retriever", retriever)
+     # builder.add_node("assistant", assistant)
+     # builder.add_node("tools", ToolNode(tools))
+     # builder.add_edge(START, "retriever")
+     # builder.add_edge("retriever", "assistant")
+     # builder.add_conditional_edges(
+     #     "assistant",
+     #     tools_condition,
+     # )
+     # builder.add_edge("tools", "assistant")
+
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+
+     # The retriever is both the entry point and the finish point
+     builder.set_entry_point("retriever")
+     builder.set_finish_point("retriever")
+
+     # Compile graph
+     return builder.compile()
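
Usage note: a minimal sketch of how the compiled graph can be invoked. The file name, provider choice, and sample question are illustrative assumptions, not part of this commit; GOOGLE_API_KEY, SUPABASE_URL, and SUPABASE_SERVICE_KEY must be set in .env for the Google provider and the Supabase retriever to work.

    # example.py (hypothetical)
    from langchain_core.messages import HumanMessage
    from agent import build_graph

    graph = build_graph(provider="google")
    result = graph.invoke({"messages": [HumanMessage(content="What is the capital of France?")]})
    # With the current graph, the answer comes from the Supabase vector store,
    # since the retriever node is both the entry and finish point.
    print(result["messages"][-1].content)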