baixianger commited on
Commit
41cb4a2
·
1 Parent(s): 3526766

add retriever tool

Browse files
Files changed (5) hide show
  1. agent.py +70 -15
  2. requirements.txt +7 -1
  3. supabase_docs.csv +0 -0
  4. system_prompt.txt +3 -1
  5. test.ipynb +214 -45
agent.py CHANGED
@@ -1,15 +1,22 @@
1
  """LangGraph Agent"""
2
- import dotenv
3
- from langgraph.graph import MessagesState
4
- from langgraph.graph import START, StateGraph
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
 
 
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
  from langchain_community.document_loaders import WikipediaLoader
10
  from langchain_community.document_loaders import ArxivLoader
 
 
11
  from langchain_core.tools import tool
12
- from langchain_core.messages import SystemMessage
 
 
 
13
 
14
  @tool
15
  def multiply(a: int, b: int) -> int:
@@ -105,6 +112,25 @@ def arvix_search(query: str) -> str:
105
  ])
106
  return {"arvix_results": formatted_search_docs}
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  tools = [
109
  multiply,
110
  add,
@@ -114,14 +140,9 @@ tools = [
114
  wiki_search,
115
  web_search,
116
  arvix_search,
 
117
  ]
118
 
119
-
120
- # Load environment variables from .env file
121
- dotenv.load_dotenv()
122
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
123
- llm_with_tools = llm.bind_tools(tools)
124
-
125
  # load the system prompt from the file
126
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
127
  system_prompt = f.read()
@@ -129,14 +150,35 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
129
  # System message
130
  sys_msg = SystemMessage(content=system_prompt)
131
 
132
- # Node
133
- def assistant(state: MessagesState):
134
- """Assistant node"""
135
- return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
136
 
137
  # Build graph function
138
- def build_graph():
139
  """Build the graph"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  builder = StateGraph(MessagesState)
142
  builder.add_node("assistant", assistant)
@@ -150,3 +192,16 @@ def build_graph():
150
 
151
  # Compile graph
152
  return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import tools_condition
6
  from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
13
+ from langchain_community.vectorstores import SupabaseVectorStore
14
+ from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
+ from langchain.tools.retriever import create_retriever_tool
17
+ from supabase.client import Client, create_client
18
+
19
+ load_dotenv()
20
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
 
112
  ])
113
  return {"arvix_results": formatted_search_docs}
114
 
115
+ # build a retriever tool
116
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
117
+ supabase: Client = create_client(
118
+ os.environ.get("SUPABASE_URL"),
119
+ os.environ.get("SUPABASE_SERVICE_KEY"))
120
+ vector_store = SupabaseVectorStore(
121
+ client=supabase,
122
+ embedding= embeddings,
123
+ table_name="documents",
124
+ query_name="match_documents_langchain",
125
+ )
126
+ question_retrieve_tool = create_retriever_tool(
127
+ vector_store.as_retriever(),
128
+ "Question Retriever",
129
+ "Find similar questions in the vector database for the given question.",
130
+ )
131
+
132
+
133
+
134
  tools = [
135
  multiply,
136
  add,
 
140
  wiki_search,
141
  web_search,
142
  arvix_search,
143
+ question_retrieve_tool
144
  ]
145
 
 
 
 
 
 
 
146
  # load the system prompt from the file
147
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
148
  system_prompt = f.read()
 
150
  # System message
151
  sys_msg = SystemMessage(content=system_prompt)
152
 
153
+
 
 
 
154
 
155
  # Build graph function
156
+ def build_graph(provider: str = "groq"):
157
  """Build the graph"""
158
+ # Load environment variables from .env file
159
+ if provider == "google":
160
+ # Google Gemini
161
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
162
+ elif provider == "groq":
163
+ # Groq https://console.groq.com/docs/models
164
+ llm = ChatGroq(model="gemma2-9b-it", temperature=0) # optional : qwen-qwq-32b
165
+ elif provider == "huggingface":
166
+ # TODO: Add huggingface endpoint
167
+ llm = ChatHuggingFace(
168
+ llm=HuggingFaceEndpoint(
169
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
170
+ temperature=0,
171
+ ),
172
+ )
173
+ else:
174
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
175
+ # Bind tools to LLM
176
+ llm_with_tools = llm.bind_tools(tools)
177
+
178
+ # Node
179
+ def assistant(state: MessagesState):
180
+ """Assistant node"""
181
+ return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
182
 
183
  builder = StateGraph(MessagesState)
184
  builder.add_node("assistant", assistant)
 
192
 
193
  # Compile graph
194
  return builder.compile()
195
+
196
+ # test
197
+ if __name__ == "__main__":
198
+ question = "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?"
199
+
200
+ # Build the graph
201
+ graph = build_graph(provider="groq")
202
+ # Run the graph
203
+ messages = [HumanMessage(content=question)]
204
+ messages = graph.invoke({"messages": messages})
205
+ answer = messages[-1].content
206
+ print(f"Question: {question}")
207
+ print(f"{answer}")
requirements.txt CHANGED
@@ -4,9 +4,15 @@ langchain
4
  langchain-community
5
  langchain-core
6
  langchain-google-genai
 
 
7
  langchain-tavily
 
8
  langgraph
 
 
9
  arxiv
10
  pymupdf
11
  wikipedia
12
- dotenv
 
 
4
  langchain-community
5
  langchain-core
6
  langchain-google-genai
7
+ langchain-huggingface
8
+ langchain-groq
9
  langchain-tavily
10
+ langchain-chroma
11
  langgraph
12
+ huggingface_hub
13
+ supabase
14
  arxiv
15
  pymupdf
16
  wikipedia
17
+ pgvector
18
+ python-dotenv
supabase_docs.csv ADDED
The diff for this file is too large to render. See raw diff
 
system_prompt.txt CHANGED
@@ -33,4 +33,6 @@ Tools:
33
  Final Answer: Rd5
34
  ==========================
35
 
36
- Now, I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
 
 
 
33
  Final Answer: Rd5
34
  ==========================
35
 
36
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
37
+ FINAL ANSWER: [YOUR FINAL ANSWER].
38
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
test.ipynb CHANGED
@@ -1,8 +1,16 @@
1
  {
2
  "cells": [
 
 
 
 
 
 
 
 
3
  {
4
  "cell_type": "code",
5
- "execution_count": 50,
6
  "id": "14e3f417",
7
  "metadata": {},
8
  "outputs": [],
@@ -21,7 +29,7 @@
21
  },
22
  {
23
  "cell_type": "code",
24
- "execution_count": 51,
25
  "id": "5e2da6fc",
26
  "metadata": {},
27
  "outputs": [
@@ -30,45 +38,28 @@
30
  "output_type": "stream",
31
  "text": [
32
  "==================================================\n",
33
- "Task ID: f0f46385-fc03-4599-b5d3-f56496c3e69f\n",
34
- "Question: In terms of geographical distance between capital cities, which 2 countries are the furthest from each other within the ASEAN bloc according to wikipedia? Answer using a comma separated list, ordering the countries by alphabetical order.\n",
35
  "Level: 2\n",
36
- "Final Answer: Indonesia, Myanmar\n",
37
  "Annotator Metadata: \n",
38
  " ├── Steps: \n",
39
- " │ ├── 1. Search the web for \"ASEAN bloc\".\n",
40
- " │ ├── 2. Click the Wikipedia result for the ASEAN Free Trade Area.\n",
41
- " │ ├── 3. Scroll down to find the list of member states.\n",
42
- " │ ├── 4. Click into the Wikipedia pages for each member state, and note its capital.\n",
43
- " │ ├── 5. Search the web for the distance between the first two capitals. The results give travel distance, not geographic distance, which might affect the answer.\n",
44
- " │ ├── 6. Thinking it might be faster to judge the distance by looking at a map, search the web for \"ASEAN bloc\" and click into the images tab.\n",
45
- " │ ├── 7. View a map of the member countries. Since they're clustered together in an arrangement that's not very linear, it's difficult to judge distances by eye.\n",
46
- " │ ├── 8. Return to the Wikipedia page for each country. Click the GPS coordinates for each capital to get the coordinates in decimal notation.\n",
47
- " │ ├── 9. Place all these coordinates into a spreadsheet.\n",
48
- " │ ├── 10. Write formulas to calculate the distance between each capital.\n",
49
- " │ ├── 11. Write formula to get the largest distance value in the spreadsheet.\n",
50
- " │ ├── 12. Note which two capitals that value corresponds to: Jakarta and Naypyidaw.\n",
51
- " ├── 13. Return to the Wikipedia pages to see which countries those respective capitals belong to: Indonesia, Myanmar.\n",
52
- " ├── Number of steps: 13\n",
53
- " ├── How long did this take?: 45 minutes\n",
54
  " ├── Tools:\n",
55
- " │ ├── 1. Search engine\n",
56
- " │ ├── 2. Web browser\n",
57
- " │ ├── 3. Microsoft Excel / Google Sheets\n",
58
- " └── Number of tools: 3\n",
59
- "==================================================\n",
60
- "Task ID: cca530fc-4052-43b2-b130-b30968d8aa44\n",
61
- "Question: Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation.\n",
62
- "Level: 1\n",
63
- "Final Answer: Rd5\n",
64
- "Annotator Metadata: \n",
65
- " ├── Steps: \n",
66
- " │ ├── Step 1: Evaluate the position of the pieces in the chess position\n",
67
- " │ ├── Step 2: Report the best move available for black: \"Rd5\"\n",
68
- " ├── Number of steps: 2\n",
69
- " ├── How long did this take?: 10 minutes\n",
70
- " ├── Tools:\n",
71
- " │ ├── 1. Image recognition tools\n",
72
  " └── Number of tools: 1\n",
73
  "==================================================\n"
74
  ]
@@ -80,7 +71,7 @@
80
  "\n",
81
  "import random\n",
82
  "# random.seed(42)\n",
83
- "random_samples = random.sample(json_QA, 2)\n",
84
  "for sample in random_samples:\n",
85
  " print(\"=\" * 50)\n",
86
  " print(f\"Task ID: {sample['task_id']}\")\n",
@@ -100,6 +91,118 @@
100
  "print(\"=\" * 50)"
101
  ]
102
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  {
104
  "cell_type": "code",
105
  "execution_count": 31,
@@ -216,6 +319,14 @@
216
  " print(f\" ├── {tool}: {count}\")"
217
  ]
218
  },
 
 
 
 
 
 
 
 
219
  {
220
  "cell_type": "code",
221
  "execution_count": 55,
@@ -301,22 +412,45 @@
301
  },
302
  {
303
  "cell_type": "code",
304
- "execution_count": 46,
305
  "id": "42fde0f8",
306
  "metadata": {},
307
  "outputs": [],
308
  "source": [
309
  "import dotenv\n",
310
- "from langgraph.graph import MessagesState\n",
311
- "from langchain_core.messages import HumanMessage, SystemMessage\n",
312
- "from langgraph.graph import START, StateGraph\n",
313
  "from langgraph.prebuilt import tools_condition\n",
314
  "from langgraph.prebuilt import ToolNode\n",
315
  "from langchain_google_genai import ChatGoogleGenerativeAI\n",
 
316
  "from langchain_community.tools.tavily_search import TavilySearchResults\n",
317
  "from langchain_community.document_loaders import WikipediaLoader\n",
318
  "from langchain_community.document_loaders import ArxivLoader\n",
 
 
 
319
  "from langchain_core.tools import tool\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  "\n",
321
  "@tool\n",
322
  "def multiply(a: int, b: int) -> int:\n",
@@ -412,6 +546,20 @@
412
  " ])\n",
413
  " return {\"arvix_results\": formatted_search_docs}\n",
414
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  "tools = [\n",
416
  " multiply,\n",
417
  " add,\n",
@@ -421,11 +569,9 @@
421
  " wiki_search,\n",
422
  " web_search,\n",
423
  " arvix_search,\n",
 
424
  "]\n",
425
  "\n",
426
- "\n",
427
- "# Load environment variables from .env file\n",
428
- "dotenv.load_dotenv()\n",
429
  "llm = ChatGoogleGenerativeAI(model=\"gemini-2.0-flash\")\n",
430
  "llm_with_tools = llm.bind_tools(tools)"
431
  ]
@@ -489,6 +635,29 @@
489
  "\n",
490
  "display(Image(graph.get_graph(xray=True).draw_mermaid_png()))"
491
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
  }
493
  ],
494
  "metadata": {
 
1
  {
2
  "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0cc4adf",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Question data"
9
+ ]
10
+ },
11
  {
12
  "cell_type": "code",
13
+ "execution_count": 2,
14
  "id": "14e3f417",
15
  "metadata": {},
16
  "outputs": [],
 
29
  },
30
  {
31
  "cell_type": "code",
32
+ "execution_count": 3,
33
  "id": "5e2da6fc",
34
  "metadata": {},
35
  "outputs": [
 
38
  "output_type": "stream",
39
  "text": [
40
  "==================================================\n",
41
+ "Task ID: ed58682d-bc52-4baa-9eb0-4eb81e1edacc\n",
42
+ "Question: What is the last word before the second chorus of the King of Pop's fifth single from his sixth studio album?\n",
43
  "Level: 2\n",
44
+ "Final Answer: stare\n",
45
  "Annotator Metadata: \n",
46
  " ├── Steps: \n",
47
+ " │ ├── 1. Google searched \"King of Pop\".\n",
48
+ " │ ├── 2. Clicked on Michael Jackson's Wikipedia.\n",
49
+ " │ ├── 3. Scrolled down to \"Discography\".\n",
50
+ " │ ├── 4. Clicked on the sixth album, \"Thriller\".\n",
51
+ " │ ├── 5. Looked under \"Singles from Thriller\".\n",
52
+ " │ ├── 6. Clicked on the fifth single, \"Human Nature\".\n",
53
+ " │ ├── 7. Google searched \"Human Nature Michael Jackson Lyrics\".\n",
54
+ " │ ├── 8. Looked at the opening result with full lyrics sourced by Musixmatch.\n",
55
+ " │ ├── 9. Looked for repeating lyrics to determine the chorus.\n",
56
+ " │ ├── 10. Determined the chorus begins with \"If they say\" and ends with \"Does he do me that way?\"\n",
57
+ " │ ├── 11. Found the second instance of the chorus within the lyrics.\n",
58
+ " │ ├── 12. Noted the last word before the second chorus - \"stare\".\n",
59
+ " ├── Number of steps: 12\n",
60
+ " ├── How long did this take?: 20 minutes\n",
 
61
  " ├── Tools:\n",
62
+ " │ ├── Web Browser\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  " └── Number of tools: 1\n",
64
  "==================================================\n"
65
  ]
 
71
  "\n",
72
  "import random\n",
73
  "# random.seed(42)\n",
74
+ "random_samples = random.sample(json_QA, 1)\n",
75
  "for sample in random_samples:\n",
76
  " print(\"=\" * 50)\n",
77
  " print(f\"Task ID: {sample['task_id']}\")\n",
 
91
  "print(\"=\" * 50)"
92
  ]
93
  },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 48,
97
+ "id": "4bb02420",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "### build a vector database based on the metadata.jsonl\n",
102
+ "# https://python.langchain.com/docs/integrations/vectorstores/supabase/\n",
103
+ "import os\n",
104
+ "from dotenv import load_dotenv\n",
105
+ "from langchain_huggingface import HuggingFaceEmbeddings\n",
106
+ "from langchain_community.vectorstores import SupabaseVectorStore\n",
107
+ "from supabase.client import Client, create_client\n",
108
+ "\n",
109
+ "\n",
110
+ "load_dotenv()\n",
111
+ "embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-mpnet-base-v2\") # dim=768\n",
112
+ "\n",
113
+ "supabase_url = os.environ.get(\"SUPABASE_URL\")\n",
114
+ "supabase_key = os.environ.get(\"SUPABASE_SERVICE_KEY\")\n",
115
+ "supabase: Client = create_client(supabase_url, supabase_key)"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "id": "a070b955",
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "# wrap the metadata.jsonl's questions and answers into a list of document\n",
126
+ "from langchain.schema import Document\n",
127
+ "docs = []\n",
128
+ "for sample in json_QA:\n",
129
+ " content = f\"Question : {sample['Question']}\\n\\nFinal answer : {sample['Final answer']}\"\n",
130
+ " doc = {\n",
131
+ " \"content\" : content,\n",
132
+ " \"metadata\" : { # meatadata的格式必须时source键,否则会报错\n",
133
+ " \"source\" : sample['task_id']\n",
134
+ " },\n",
135
+ " \"embedding\" : embeddings.embed_query(content),\n",
136
+ " }\n",
137
+ " docs.append(doc)\n",
138
+ "\n",
139
+ "# upload the documents to the vector database\n",
140
+ "try:\n",
141
+ " response = (\n",
142
+ " supabase.table(\"documents\")\n",
143
+ " .insert(docs)\n",
144
+ " .execute()\n",
145
+ " )\n",
146
+ "except Exception as exception:\n",
147
+ " print(\"Error inserting data into Supabase:\", exception)\n",
148
+ "\n",
149
+ "# ALTERNATIVE : Save the documents (a list of dict) into a csv file, and manually upload it to Supabase\n",
150
+ "# import pandas as pd\n",
151
+ "# df = pd.DataFrame(docs)\n",
152
+ "# df.to_csv('supabase_docs.csv', index=False)"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": 54,
158
+ "id": "77fb9dbb",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "# add items to vector database\n",
163
+ "vector_store = SupabaseVectorStore(\n",
164
+ " client=supabase,\n",
165
+ " embedding= embeddings,\n",
166
+ " table_name=\"documents\",\n",
167
+ " query_name=\"match_documents_langchain\",\n",
168
+ ")\n",
169
+ "retriever = vector_store.as_retriever()"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 55,
175
+ "id": "12a05971",
176
+ "metadata": {},
177
+ "outputs": [
178
+ {
179
+ "name": "stderr",
180
+ "output_type": "stream",
181
+ "text": [
182
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
183
+ "To disable this warning, you can either:\n",
184
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
185
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
186
+ ]
187
+ },
188
+ {
189
+ "data": {
190
+ "text/plain": [
191
+ "Document(metadata={'source': '840bfca7-4f7b-481a-8794-c560c340185d'}, page_content='Question : On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?\\n\\nFinal answer : 80GSFC21M0002')"
192
+ ]
193
+ },
194
+ "execution_count": 55,
195
+ "metadata": {},
196
+ "output_type": "execute_result"
197
+ }
198
+ ],
199
+ "source": [
200
+ "query = \"On June 6, 2023, an article by Carolyn Collins Petersen was published in Universe Today. This article mentions a team that produced a paper about their observations, linked at the bottom of the article. Find this paper. Under what NASA award number was the work performed by R. G. Arendt supported by?\"\n",
201
+ "# matched_docs = vector_store.similarity_search(query, 2)\n",
202
+ "docs = retriever.invoke(query)\n",
203
+ "docs[0]"
204
+ ]
205
+ },
206
  {
207
  "cell_type": "code",
208
  "execution_count": 31,
 
319
  " print(f\" ├── {tool}: {count}\")"
320
  ]
321
  },
322
+ {
323
+ "cell_type": "markdown",
324
+ "id": "5efee12a",
325
+ "metadata": {},
326
+ "source": [
327
+ "#### Graph"
328
+ ]
329
+ },
330
  {
331
  "cell_type": "code",
332
  "execution_count": 55,
 
412
  },
413
  {
414
  "cell_type": "code",
415
+ "execution_count": null,
416
  "id": "42fde0f8",
417
  "metadata": {},
418
  "outputs": [],
419
  "source": [
420
  "import dotenv\n",
421
+ "from langgraph.graph import MessagesState, START, StateGraph\n",
 
 
422
  "from langgraph.prebuilt import tools_condition\n",
423
  "from langgraph.prebuilt import ToolNode\n",
424
  "from langchain_google_genai import ChatGoogleGenerativeAI\n",
425
+ "from langchain_huggingface import HuggingFaceEmbeddings\n",
426
  "from langchain_community.tools.tavily_search import TavilySearchResults\n",
427
  "from langchain_community.document_loaders import WikipediaLoader\n",
428
  "from langchain_community.document_loaders import ArxivLoader\n",
429
+ "from langchain_community.vectorstores import SupabaseVectorStore\n",
430
+ "from langchain.tools.retriever import create_retriever_tool\n",
431
+ "from langchain_core.messages import HumanMessage, SystemMessage\n",
432
  "from langchain_core.tools import tool\n",
433
+ "from supabase.client import Client, create_client\n",
434
+ "\n",
435
+ "# Define the retriever from supabase\n",
436
+ "load_dotenv()\n",
437
+ "embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-mpnet-base-v2\") # dim=768\n",
438
+ "\n",
439
+ "supabase_url = os.environ.get(\"SUPABASE_URL\")\n",
440
+ "supabase_key = os.environ.get(\"SUPABASE_SERVICE_KEY\")\n",
441
+ "supabase: Client = create_client(supabase_url, supabase_key)\n",
442
+ "vector_store = SupabaseVectorStore(\n",
443
+ " client=supabase,\n",
444
+ " embedding= embeddings,\n",
445
+ " table_name=\"documents\",\n",
446
+ " query_name=\"match_documents_langchain\",\n",
447
+ ")\n",
448
+ "\n",
449
+ "question_retrieve_tool = create_retriever_tool(\n",
450
+ " vector_store.as_retriever(),\n",
451
+ " \"Question Retriever\",\n",
452
+ " \"Find similar questions in the vector database for the given question.\",\n",
453
+ ")\n",
454
  "\n",
455
  "@tool\n",
456
  "def multiply(a: int, b: int) -> int:\n",
 
546
  " ])\n",
547
  " return {\"arvix_results\": formatted_search_docs}\n",
548
  "\n",
549
+ "@tool\n",
550
+ "def similar_question_search(question: str) -> str:\n",
551
+ " \"\"\"Search the vector database for similar questions and return the first results.\n",
552
+ " \n",
553
+ " Args:\n",
554
+ " question: the question human provided.\"\"\"\n",
555
+ " matched_docs = vector_store.similarity_search(query, 3)\n",
556
+ " formatted_search_docs = \"\\n\\n---\\n\\n\".join(\n",
557
+ " [\n",
558
+ " f'<Document source=\"{doc.metadata[\"source\"]}\" page=\"{doc.metadata.get(\"page\", \"\")}\"/>\\n{doc.page_content[:1000]}\\n</Document>'\n",
559
+ " for doc in matched_docs\n",
560
+ " ])\n",
561
+ " return {\"similar_questions\": formatted_search_docs}\n",
562
+ "\n",
563
  "tools = [\n",
564
  " multiply,\n",
565
  " add,\n",
 
569
  " wiki_search,\n",
570
  " web_search,\n",
571
  " arvix_search,\n",
572
+ " question_retrieve_tool\n",
573
  "]\n",
574
  "\n",
 
 
 
575
  "llm = ChatGoogleGenerativeAI(model=\"gemini-2.0-flash\")\n",
576
  "llm_with_tools = llm.bind_tools(tools)"
577
  ]
 
635
  "\n",
636
  "display(Image(graph.get_graph(xray=True).draw_mermaid_png()))"
637
  ]
638
+ },
639
+ {
640
+ "cell_type": "code",
641
+ "execution_count": null,
642
+ "id": "5987d58c",
643
+ "metadata": {},
644
+ "outputs": [],
645
+ "source": [
646
+ "question = \"\"\n",
647
+ "messages = [HumanMessage(content=question)]\n",
648
+ "messages = graph.invoke({\"messages\": messages})"
649
+ ]
650
+ },
651
+ {
652
+ "cell_type": "code",
653
+ "execution_count": null,
654
+ "id": "330cbf17",
655
+ "metadata": {},
656
+ "outputs": [],
657
+ "source": [
658
+ "for m in messages['messages']:\n",
659
+ " m.pretty_print()"
660
+ ]
661
  }
662
  ],
663
  "metadata": {