baixianger commited on
Commit
3839c42
·
1 Parent(s): 66a348b
Files changed (1) hide show
  1. agent.py +21 -11
agent.py CHANGED
@@ -13,6 +13,7 @@ from langchain_community.document_loaders import ArxivLoader
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
 
16
  from supabase.client import Client, create_client
17
 
18
  load_dotenv()
@@ -111,16 +112,7 @@ def arvix_search(query: str) -> str:
111
  ])
112
  return {"arvix_results": formatted_search_docs}
113
 
114
- tools = [
115
- multiply,
116
- add,
117
- subtract,
118
- divide,
119
- modulus,
120
- wiki_search,
121
- web_search,
122
- arvix_search,
123
- ]
124
 
125
  # load the system prompt from the file
126
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
@@ -140,6 +132,24 @@ vector_store = SupabaseVectorStore(
140
  table_name="documents",
141
  query_name="match_documents_langchain",
142
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  # Build graph function
145
  def build_graph(provider: str = "groq"):
@@ -162,7 +172,7 @@ def build_graph(provider: str = "groq"):
162
  else:
163
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
164
  # Bind tools to LLM
165
- llm_with_tools = llm.bind_tools(tools, tool_choice="Question Search")
166
 
167
  # Node
168
  def assistant(state: MessagesState):
 
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
  from supabase.client import Client, create_client
18
 
19
  load_dotenv()
 
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
115
+
 
 
 
 
 
 
 
 
 
116
 
117
  # load the system prompt from the file
118
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
 
132
  table_name="documents",
133
  query_name="match_documents_langchain",
134
  )
135
+ create_retriever_tool = create_retriever_tool(
136
+ retriever=vector_store.as_retriever(),
137
+ name="Question Search",
138
+ description="A tool to retrieve similar questions from a vector store.",
139
+ )
140
+
141
+
142
+
143
+ tools = [
144
+ multiply,
145
+ add,
146
+ subtract,
147
+ divide,
148
+ modulus,
149
+ wiki_search,
150
+ web_search,
151
+ arvix_search,
152
+ ]
153
 
154
  # Build graph function
155
  def build_graph(provider: str = "groq"):
 
172
  else:
173
  raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
174
  # Bind tools to LLM
175
+ llm_with_tools = llm.bind_tools(tools)
176
 
177
  # Node
178
  def assistant(state: MessagesState):