Datawithsarah commited on
Commit
7df20a1
·
1 Parent(s): 54c62fb

changes to avoid runtime error

Browse files
Files changed (2) hide show
  1. agent.py +55 -38
  2. app.py +1 -0
agent.py CHANGED
@@ -18,98 +18,113 @@ from supabase.client import Client, create_client
18
 
19
  load_dotenv()
20
 
21
- # === TOOLS === #
22
-
23
  @tool
24
- def multiply(a: int, b: int) -> int: return a * b
 
 
25
 
26
  @tool
27
- def add(a: int, b: int) -> int: return a + b
 
 
28
 
29
  @tool
30
- def subtract(a: int, b: int) -> int: return a - b
 
 
31
 
32
  @tool
33
  def divide(a: int, b: int) -> float:
 
34
  if b == 0:
35
  raise ValueError("Cannot divide by zero.")
36
  return a / b
37
 
38
  @tool
39
- def modulus(a: int, b: int) -> int: return a % b
 
 
40
 
41
  @tool
42
  def wiki_search(query: str) -> str:
43
- docs = WikipediaLoader(query=query, load_max_docs=2).load()
44
- return {"wiki_results": "\n\n---\n\n".join(doc.page_content for doc in docs)}
 
 
 
 
45
 
46
  @tool
47
  def web_search(query: str) -> str:
48
- docs = TavilySearchResults(max_results=3).invoke(query)
49
- return {"web_results": "\n\n---\n\n".join(doc.page_content for doc in docs)}
 
 
 
 
50
 
51
  @tool
52
  def arvix_search(query: str) -> str:
53
- docs = ArxivLoader(query=query, load_max_docs=3).load()
54
- return {"arvix_results": "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)}
55
-
56
- # === SYSTEM PROMPT === #
57
-
 
 
 
58
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
59
  system_prompt = f.read()
60
 
61
  sys_msg = SystemMessage(content=system_prompt)
62
 
63
- # === EMBEDDING + RETRIEVER === #
64
-
65
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
66
- supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
67
  vector_store = SupabaseVectorStore(
68
  client=supabase,
69
  embedding=embeddings,
70
  table_name="Vector_Test",
71
  query_name="match_documents_langchain",
72
  )
 
73
  create_retriever_tool = create_retriever_tool(
74
  retriever=vector_store.as_retriever(),
75
  name="Question Search",
76
  description="A tool to retrieve similar questions from a vector store."
77
  )
78
 
79
- # === TOOL LIST === #
80
  tools = [
81
  multiply, add, subtract, divide, modulus,
82
  wiki_search, web_search, arvix_search
83
  ]
84
 
85
- # === BUILD GRAPH === #
86
  def build_graph(provider: str = "groq"):
87
  if provider == "google":
88
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
89
  elif provider == "groq":
90
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
91
  elif provider == "huggingface":
92
- llm = ChatHuggingFace(llm=HuggingFaceEndpoint(
93
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
94
- temperature=0))
 
 
 
95
  else:
96
- raise ValueError("Invalid provider.")
97
 
98
  llm_with_tools = llm.bind_tools(tools)
99
 
100
  def assistant(state: MessagesState):
101
- response = llm_with_tools.invoke(state["messages"])
102
- answer = response.content.strip()
103
- if "FINAL ANSWER:" not in answer:
104
- answer = f"FINAL ANSWER: {answer.strip().splitlines()[0]}"
105
- return {"messages": [AIMessage(content=answer)]}
106
 
107
  def retriever(state: MessagesState):
108
  similar = vector_store.similarity_search(state["messages"][0].content)
109
- if similar:
110
- ref = HumanMessage(content=f"Here is a similar example: \n{similar[0].page_content}")
111
- return {"messages": [sys_msg] + state["messages"] + [ref]}
112
- return {"messages": [sys_msg] + state["messages"]}
113
 
114
  builder = StateGraph(MessagesState)
115
  builder.add_node("retriever", retriever)
@@ -119,11 +134,13 @@ def build_graph(provider: str = "groq"):
119
  builder.add_edge("retriever", "assistant")
120
  builder.add_conditional_edges("assistant", tools_condition)
121
  builder.add_edge("tools", "assistant")
 
122
  return builder.compile()
123
 
124
  if __name__ == "__main__":
125
- graph = build_graph()
126
- question = "What is 12 + 4?"
127
- result = graph.invoke({"messages": [HumanMessage(content=question)]})
128
- for m in result["messages"]:
129
- print(m.content)
 
 
18
 
19
  load_dotenv()
20
 
 
 
21
  @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two integers and return the result."""
24
+ return a * b
25
 
26
  @tool
27
+ def add(a: int, b: int) -> int:
28
+ """Add two integers and return the result."""
29
+ return a + b
30
 
31
  @tool
32
+ def subtract(a: int, b: int) -> int:
33
+ """Subtract the second integer from the first and return the result."""
34
+ return a - b
35
 
36
  @tool
37
  def divide(a: int, b: int) -> float:
38
+ """Divide the first integer by the second and return the result as float."""
39
  if b == 0:
40
  raise ValueError("Cannot divide by zero.")
41
  return a / b
42
 
43
  @tool
44
+ def modulus(a: int, b: int) -> int:
45
+ """Return the remainder when the first integer is divided by the second."""
46
+ return a % b
47
 
48
  @tool
49
  def wiki_search(query: str) -> str:
50
+ """Search Wikipedia for a query and return up to 2 results."""
51
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
52
+ formatted = "\n\n---\n\n".join(
53
+ [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>' for doc in search_docs]
54
+ )
55
+ return {"wiki_results": formatted}
56
 
57
  @tool
58
  def web_search(query: str) -> str:
59
+ """Search Tavily for a query and return up to 3 results."""
60
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
61
+ formatted = "\n\n---\n\n".join(
62
+ [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>' for doc in search_docs]
63
+ )
64
+ return {"web_results": formatted}
65
 
66
  @tool
67
  def arvix_search(query: str) -> str:
68
+ """Search Arxiv for a query and return up to 3 results."""
69
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
70
+ formatted = "\n\n---\n\n".join(
71
+ [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>' for doc in search_docs]
72
+ )
73
+ return {"arvix_results": formatted}
74
+
75
+ # Load system prompt
76
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
77
  system_prompt = f.read()
78
 
79
  sys_msg = SystemMessage(content=system_prompt)
80
 
81
+ # Setup Supabase vector retriever
 
82
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
83
+ supabase: Client = create_client(os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_KEY"))
84
  vector_store = SupabaseVectorStore(
85
  client=supabase,
86
  embedding=embeddings,
87
  table_name="Vector_Test",
88
  query_name="match_documents_langchain",
89
  )
90
+
91
  create_retriever_tool = create_retriever_tool(
92
  retriever=vector_store.as_retriever(),
93
  name="Question Search",
94
  description="A tool to retrieve similar questions from a vector store."
95
  )
96
 
97
+ # Define tool list
98
  tools = [
99
  multiply, add, subtract, divide, modulus,
100
  wiki_search, web_search, arvix_search
101
  ]
102
 
103
+ # Build graph
104
  def build_graph(provider: str = "groq"):
105
  if provider == "google":
106
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
107
  elif provider == "groq":
108
  llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
109
  elif provider == "huggingface":
110
+ llm = ChatHuggingFace(
111
+ llm=HuggingFaceEndpoint(
112
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
113
+ temperature=0,
114
+ )
115
+ )
116
  else:
117
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
118
 
119
  llm_with_tools = llm.bind_tools(tools)
120
 
121
  def assistant(state: MessagesState):
122
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
123
 
124
  def retriever(state: MessagesState):
125
  similar = vector_store.similarity_search(state["messages"][0].content)
126
+ example_msg = HumanMessage(content=f"Here I provide a similar question and answer for reference: \n\n{similar[0].page_content}")
127
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
 
 
128
 
129
  builder = StateGraph(MessagesState)
130
  builder.add_node("retriever", retriever)
 
134
  builder.add_edge("retriever", "assistant")
135
  builder.add_conditional_edges("assistant", tools_condition)
136
  builder.add_edge("tools", "assistant")
137
+
138
  return builder.compile()
139
 
140
  if __name__ == "__main__":
141
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
142
+ graph = build_graph("groq")
143
+ messages = [HumanMessage(content=question)]
144
+ messages = graph.invoke({"messages": messages})
145
+ for m in messages["messages"]:
146
+ m.pretty_print()
app.py CHANGED
@@ -16,6 +16,7 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
18
 
 
19
  cached_answers = []
20
 
21
  class BasicAgent:
 
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
18
 
19
+
20
  cached_answers = []
21
 
22
  class BasicAgent: