Datawithsarah committed on
Commit
54c62fb
·
1 Parent(s): aa8a441

refined prompt

Browse files
Files changed (3) hide show
  1. agent.py +46 -123
  2. app.py +7 -7
  3. system_prompt.txt +18 -5
agent.py CHANGED
@@ -18,166 +18,98 @@ from supabase.client import Client, create_client
18
 
19
  load_dotenv()
20
 
 
 
21
  @tool
22
- def multiply(a: int, b: int) -> int:
23
- """Multiply two numbers.
24
- Args:
25
- a: first int
26
- b: second int
27
- """
28
- return a * b
29
 
30
  @tool
31
- def add(a: int, b: int) -> int:
32
- """Add two numbers.
33
- Args:
34
- a: first int
35
- b: second int
36
- """
37
- return a + b
38
 
39
  @tool
40
- def subtract(a: int, b: int) -> int:
41
- """Subtract two numbers.
42
- Args:
43
- a: first int
44
- b: second int
45
- """
46
- return a - b
47
 
48
  @tool
49
- def divide(a: int, b: int) -> int:
50
- """Divide two numbers.
51
- Args:
52
- a: first int
53
- b: second int
54
- """
55
  if b == 0:
56
  raise ValueError("Cannot divide by zero.")
57
  return a / b
58
 
59
  @tool
60
- def modulus(a: int, b: int) -> int:
61
- """Get the modulus of two numbers.
62
- Args:
63
- a: first int
64
- b: second int
65
- """
66
- return a % b
67
 
68
  @tool
69
  def wiki_search(query: str) -> str:
70
- """Search Wikipedia for a query and return maximum 2 results.
71
- Args:
72
- query: The search query."""
73
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
74
- formatted_search_docs = "\n\n---\n\n".join(
75
- [
76
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
77
- for doc in search_docs
78
- ])
79
- return {"wiki_results": formatted_search_docs}
80
 
81
  @tool
82
  def web_search(query: str) -> str:
83
- """Search Tavily for a query and return maximum 3 results.
84
- Args:
85
- query: The search query."""
86
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
87
- formatted_search_docs = "\n\n---\n\n".join(
88
- [
89
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
90
- for doc in search_docs
91
- ])
92
- return {"web_results": formatted_search_docs}
93
 
94
  @tool
95
  def arvix_search(query: str) -> str:
96
- """Search Arxiv for a query and return maximum 3 result.
97
- Args:
98
- query: The search query."""
99
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
100
- formatted_search_docs = "\n\n---\n\n".join(
101
- [
102
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
103
- for doc in search_docs
104
- ])
105
- return {"arvix_results": formatted_search_docs}
106
-
107
 
 
108
 
109
- # load the system prompt from the file
110
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
111
  system_prompt = f.read()
112
 
113
- # System message
114
  sys_msg = SystemMessage(content=system_prompt)
115
 
116
- # build a retriever
117
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
118
- supabase: Client = create_client(
119
- os.environ.get("SUPABASE_URL"),
120
- os.environ.get("SUPABASE_SERVICE_KEY"))
121
  vector_store = SupabaseVectorStore(
122
  client=supabase,
123
- embedding= embeddings,
124
  table_name="Vector_Test",
125
  query_name="match_documents_langchain",
126
  )
127
  create_retriever_tool = create_retriever_tool(
128
  retriever=vector_store.as_retriever(),
129
  name="Question Search",
130
- description="A tool to retrieve similar questions from a vector store.",
131
  )
132
 
133
-
134
-
135
  tools = [
136
- multiply,
137
- add,
138
- subtract,
139
- divide,
140
- modulus,
141
- wiki_search,
142
- web_search,
143
- arvix_search,
144
  ]
145
 
146
- # Build graph function
147
  def build_graph(provider: str = "groq"):
148
- """Build the graph"""
149
- # Load environment variables from .env file
150
  if provider == "google":
151
- # Google Gemini
152
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
153
  elif provider == "groq":
154
- # Groq https://console.groq.com/docs/models
155
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
156
  elif provider == "huggingface":
157
- # TODO: Add huggingface endpoint
158
- llm = ChatHuggingFace(
159
- llm=HuggingFaceEndpoint(
160
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
161
- temperature=0,
162
- ),
163
- )
164
  else:
165
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
166
- # Bind tools to LLM
167
  llm_with_tools = llm.bind_tools(tools)
168
 
169
- # Node
170
  def assistant(state: MessagesState):
171
- """Assistant node"""
172
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
173
 
174
  def retriever(state: MessagesState):
175
- """Retriever node"""
176
- similar_question = vector_store.similarity_search(state["messages"][0].content)
177
- example_msg = HumanMessage(
178
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
179
- )
180
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
181
 
182
  builder = StateGraph(MessagesState)
183
  builder.add_node("retriever", retriever)
@@ -185,22 +117,13 @@ def build_graph(provider: str = "groq"):
185
  builder.add_node("tools", ToolNode(tools))
186
  builder.add_edge(START, "retriever")
187
  builder.add_edge("retriever", "assistant")
188
- builder.add_conditional_edges(
189
- "assistant",
190
- tools_condition,
191
- )
192
  builder.add_edge("tools", "assistant")
193
-
194
- # Compile graph
195
  return builder.compile()
196
 
197
- # test
198
  if __name__ == "__main__":
199
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
200
- # Build the graph
201
- graph = build_graph(provider="groq")
202
- # Run the graph
203
- messages = [HumanMessage(content=question)]
204
- messages = graph.invoke({"messages": messages})
205
- for m in messages["messages"]:
206
- m.pretty_print()
 
18
 
19
  load_dotenv()
20
 
21
+ # === TOOLS === #
22
+
23
  @tool
24
+ def multiply(a: int, b: int) -> int: return a * b
 
 
 
 
 
 
25
 
26
  @tool
27
+ def add(a: int, b: int) -> int: return a + b
 
 
 
 
 
 
28
 
29
  @tool
30
+ def subtract(a: int, b: int) -> int: return a - b
 
 
 
 
 
 
31
 
32
  @tool
33
+ def divide(a: int, b: int) -> float:
 
 
 
 
 
34
  if b == 0:
35
  raise ValueError("Cannot divide by zero.")
36
  return a / b
37
 
38
  @tool
39
+ def modulus(a: int, b: int) -> int: return a % b
 
 
 
 
 
 
40
 
41
  @tool
42
  def wiki_search(query: str) -> str:
43
+ docs = WikipediaLoader(query=query, load_max_docs=2).load()
44
+ return {"wiki_results": "\n\n---\n\n".join(doc.page_content for doc in docs)}
 
 
 
 
 
 
 
 
45
 
46
  @tool
47
  def web_search(query: str) -> str:
48
+ docs = TavilySearchResults(max_results=3).invoke(query)
49
+ return {"web_results": "\n\n---\n\n".join(doc.page_content for doc in docs)}
 
 
 
 
 
 
 
 
50
 
51
  @tool
52
  def arvix_search(query: str) -> str:
53
+ docs = ArxivLoader(query=query, load_max_docs=3).load()
54
+ return {"arvix_results": "\n\n---\n\n".join(doc.page_content[:1000] for doc in docs)}
 
 
 
 
 
 
 
 
 
55
 
56
+ # === SYSTEM PROMPT === #
57
 
 
58
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
59
  system_prompt = f.read()
60
 
 
61
  sys_msg = SystemMessage(content=system_prompt)
62
 
63
+ # === EMBEDDING + RETRIEVER === #
64
+
65
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
66
+ supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
 
67
  vector_store = SupabaseVectorStore(
68
  client=supabase,
69
+ embedding=embeddings,
70
  table_name="Vector_Test",
71
  query_name="match_documents_langchain",
72
  )
73
  create_retriever_tool = create_retriever_tool(
74
  retriever=vector_store.as_retriever(),
75
  name="Question Search",
76
+ description="A tool to retrieve similar questions from a vector store."
77
  )
78
 
79
+ # === TOOL LIST === #
 
80
  tools = [
81
+ multiply, add, subtract, divide, modulus,
82
+ wiki_search, web_search, arvix_search
 
 
 
 
 
 
83
  ]
84
 
85
+ # === BUILD GRAPH === #
86
  def build_graph(provider: str = "groq"):
 
 
87
  if provider == "google":
 
88
  llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
89
  elif provider == "groq":
90
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
 
91
  elif provider == "huggingface":
92
+ llm = ChatHuggingFace(llm=HuggingFaceEndpoint(
93
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
94
+ temperature=0))
 
 
 
 
95
  else:
96
+ raise ValueError("Invalid provider.")
97
+
98
  llm_with_tools = llm.bind_tools(tools)
99
 
 
100
  def assistant(state: MessagesState):
101
+ response = llm_with_tools.invoke(state["messages"])
102
+ answer = response.content.strip()
103
+ if "FINAL ANSWER:" not in answer:
104
+ answer = f"FINAL ANSWER: {answer.strip().splitlines()[0]}"
105
+ return {"messages": [AIMessage(content=answer)]}
106
 
107
  def retriever(state: MessagesState):
108
+ similar = vector_store.similarity_search(state["messages"][0].content)
109
+ if similar:
110
+ ref = HumanMessage(content=f"Here is a similar example: \n{similar[0].page_content}")
111
+ return {"messages": [sys_msg] + state["messages"] + [ref]}
112
+ return {"messages": [sys_msg] + state["messages"]}
 
113
 
114
  builder = StateGraph(MessagesState)
115
  builder.add_node("retriever", retriever)
 
117
  builder.add_node("tools", ToolNode(tools))
118
  builder.add_edge(START, "retriever")
119
  builder.add_edge("retriever", "assistant")
120
+ builder.add_conditional_edges("assistant", tools_condition)
 
 
 
121
  builder.add_edge("tools", "assistant")
 
 
122
  return builder.compile()
123
 
 
124
  if __name__ == "__main__":
125
+ graph = build_graph()
126
+ question = "What is 12 + 4?"
127
+ result = graph.invoke({"messages": [HumanMessage(content=question)]})
128
+ for m in result["messages"]:
129
+ print(m.content)
 
 
 
app.py CHANGED
@@ -16,7 +16,6 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
18
 
19
-
20
  cached_answers = []
21
 
22
  class BasicAgent:
@@ -29,8 +28,10 @@ class BasicAgent:
29
  print(f"Agent received question (first 50 chars): {question[:50]}...")
30
  messages = [HumanMessage(content=question)]
31
  messages = self.graph.invoke({"messages": messages})
32
- answer = messages['messages'][-1].content
33
- return answer[14:]
 
 
34
 
35
  def run_agent_only(profile: gr.OAuthProfile | None):
36
  global cached_answers
@@ -45,8 +46,7 @@ def run_agent_only(profile: gr.OAuthProfile | None):
45
  except Exception as e:
46
  return f"Agent Init Error: {e}", None
47
 
48
- api_url = "https://agents-course-unit4-scoring.hf.space"
49
- questions_url = f"{api_url}/questions"
50
 
51
  try:
52
  response = requests.get(questions_url, timeout=15)
@@ -93,8 +93,8 @@ def submit_cached_answers(profile: gr.OAuthProfile | None):
93
  "agent_code": agent_code,
94
  "answers": cached_answers
95
  }
96
-
97
- submit_url = "https://agents-course-unit4-scoring.hf.space/submit"
98
 
99
  try:
100
  response = requests.post(submit_url, json=payload, timeout=60)
 
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
18
 
 
19
  cached_answers = []
20
 
21
  class BasicAgent:
 
28
  print(f"Agent received question (first 50 chars): {question[:50]}...")
29
  messages = [HumanMessage(content=question)]
30
  messages = self.graph.invoke({"messages": messages})
31
+ raw_answer = messages['messages'][-1].content
32
+ if raw_answer.startswith("FINAL ANSWER: "):
33
+ return raw_answer[len("FINAL ANSWER: "):].strip()
34
+ return f"Agent response did not follow FINAL ANSWER format: {raw_answer}"
35
 
36
  def run_agent_only(profile: gr.OAuthProfile | None):
37
  global cached_answers
 
46
  except Exception as e:
47
  return f"Agent Init Error: {e}", None
48
 
49
+ questions_url = f"{DEFAULT_API_URL}/questions"
 
50
 
51
  try:
52
  response = requests.get(questions_url, timeout=15)
 
93
  "agent_code": agent_code,
94
  "answers": cached_answers
95
  }
96
+
97
+ submit_url = f"{DEFAULT_API_URL}/submit"
98
 
99
  try:
100
  response = requests.post(submit_url, json=payload, timeout=60)
system_prompt.txt CHANGED
@@ -1,5 +1,18 @@
1
- You are a helpful assistant tasked with answering questions using a set of tools.
2
- Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
3
- FINAL ANSWER: [YOUR FINAL ANSWER].
4
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
5
- Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a helpful assistant answering questions using a set of tools.
2
+
3
+ You must strictly follow this output format:
4
+
5
+ FINAL ANSWER: [YOUR FINAL ANSWER]
6
+
7
+ Where [YOUR FINAL ANSWER] is:
8
+ - A number (e.g., 42) → Do NOT use commas, units ($, %, etc.), or extra words.
9
+ - A string (e.g., Paris) → Do NOT use articles (e.g., "the", "an"), abbreviations, or numeric digits unless required.
10
+ - A comma-separated list (e.g., apple, banana, cherry) → Apply the above rules to each item.
11
+
12
+ Important:
13
+ - Always begin your final output with **exactly** "FINAL ANSWER: ".
14
+ - Do NOT include any reasoning or explanation after your final answer.
15
+ - Do NOT end the answer with a period or add any trailing text.
16
+ - Think step-by-step internally, but return **only** the FINAL ANSWER line in your output.
17
+
18
+ I will now ask you a question.