Update agent.py
agent.py (changed)
@@ -5,9 +5,9 @@ from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition
 from langgraph.prebuilt import ToolNode
 from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_groq import ChatGroq
+#from langchain_groq import ChatGroq
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
-from langchain_community.tools.tavily_search import TavilySearchResults
+#from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 from langchain_community.document_loaders import ArxivLoader
 from langchain_community.vectorstores import SupabaseVectorStore
@@ -84,6 +84,7 @@ def wiki_search(query: str) -> str:
         ])
     return {"wiki_results": formatted_search_docs}
 
+"""
 @tool
 def web_search(query: str) -> str:
     """Search Tavily for a query and return maximum 3 results.
@@ -98,6 +99,8 @@ def web_search(query: str) -> str:
         ])
     return {"web_results": formatted_search_docs}
 
+"""
+
 @tool
 def arvix_search(query: str) -> str:
     """Search Arxiv for a query and return maximum 3 result.
@@ -118,18 +121,19 @@ def arvix_search(query: str) -> str:
 #with open("system_prompt.txt", "r", encoding="utf-8") as f:
 #    system_prompt = f.read()
 
-system_prompt = """
+system_prompt = """You are a helpful assistant tasked with answering questions using a set of tools.
 Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
 FINAL ANSWER: [YOUR FINAL ANSWER].
 YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
 If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
 If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
 If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
-Your answer should only start with "FINAL ANSWER: ", then follows with the answer."""
+Your answer should only start with "FINAL ANSWER: ", then follows with the answer.""".strip()
 
 # System message
 sys_msg = SystemMessage(content=system_prompt)
 
+"""
 # build a retriever
 embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
 supabase: Client = create_client(
@@ -146,7 +150,7 @@ create_retriever_tool = create_retriever_tool(
     name="Question Search",
     description="A tool to retrieve similar questions from a vector store.",
 )
-
+"""
 
 
 tools = [
@@ -156,7 +160,7 @@ tools = [
     divide,
     modulus,
     wiki_search,
-    web_search,
+    #web_search,
    arvix_search,
 ]
 
@@ -197,11 +201,12 @@ def build_graph(provider: str = "huggingface"):
         return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
     builder = StateGraph(MessagesState)
-    builder.add_node("retriever", retriever)
+    #builder.add_node("retriever", retriever)
     builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
-    builder.add_edge(START, "retriever")
-    builder.add_edge("retriever", "assistant")
+    #builder.add_edge(START, "retriever")
+    builder.add_edge(START, "assistant")
+    #builder.add_edge("retriever", "assistant")
     builder.add_conditional_edges(
         "assistant",
         tools_condition,
@@ -215,7 +220,7 @@ def build_graph(provider: str = "huggingface"):
 if __name__ == "__main__":
     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
     # Build the graph
-    graph = build_graph(provider="
+    graph = build_graph(provider="huggingface")
     # Run the graph
     messages = [HumanMessage(content=question)]
     messages = graph.invoke({"messages": messages})
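
For reference, the graph that agent.py builds after this change runs START -> assistant -> tools_condition -> tools -> assistant, with the retriever node and the Tavily web_search tool disabled. The sketch below is a minimal, self-contained reconstruction of that wiring, not the full file: the single multiply tool, the model id, and the shortened system prompt are placeholders standing in for the real tool list, endpoint, and prompt used in agent.py.

# Minimal sketch (not the full agent.py): the wiring this commit converges on.
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

tools = [multiply]  # placeholder for the full tool list (math tools, wiki_search, arvix_search, ...)

# Placeholder model id; agent.py only shows that a HuggingFace chat model is used.
llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"))
llm_with_tools = llm.bind_tools(tools)

sys_msg = SystemMessage(content="You are a helpful assistant tasked with answering questions using a set of tools.")

def assistant(state: MessagesState):
    """Call the tool-bound chat model on the accumulated message history."""
    return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")                         # retriever hop removed in this commit
builder.add_conditional_edges("assistant", tools_condition)  # route to "tools" or END
builder.add_edge("tools", "assistant")                       # feed tool results back to the model
graph = builder.compile()

if __name__ == "__main__":
    result = graph.invoke({"messages": [HumanMessage(content="What is 6 * 7?")]})
    print(result["messages"][-1].content)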