AC-Angelo93 committed
Commit 3fde5b7 · verified · 1 Parent(s): a89853e

Update agent.py

Files changed (1):
  agent.py  (+57, -5)
agent.py CHANGED
@@ -11,7 +11,7 @@ from serpapi import GoogleSearch
 from langgraph.graph import StateGraph
 from langchain_core.language_models.llms import LLM
 from langchain_core.messages import SystemMessage, HumanMessage
-from langchain.tools import tool
+from langchain_core.tools import tool
 from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
 
 # ────────────────
@@ -32,39 +32,91 @@ with open("system_prompt.txt","r",encoding="utf-8") as f:
     SYSTEM_PROMPT = f.read().strip()
 
 # ────────────────
-# 4️⃣ Define your tools (unchanged semantics)
+# 4️⃣ Define your tools (with docstrings)
 # ────────────────
+
 @tool
 def calculator(expr: str) -> str:
+    """
+    Evaluate the given Python expression and return its result as a string.
+    Returns "Error" if evaluation fails.
+    """
     try:
         return str(eval(expr))
-    except:
+    except Exception:
         return "Error"
 
+
 @tool
 def retrieve_docs(query: str, k: int = 3) -> str:
+    """
+    Perform vector similarity search over the FAISS index.
+
+    Args:
+        query: the user’s query string to embed and search for.
+        k: the number of nearest documents to return (default 3).
+
+    Returns:
+        The top-k document contents concatenated into one string.
+    """
     q_emb = EMBEDDER.encode([query]).astype("float32")
     D, I = INDEX.search(q_emb, k)
     return "\n\n---\n\n".join(DOCS[i] for i in I[0])
 
+
 SERPAPI_KEY = os.getenv("SERPAPI_KEY")
+
 @tool
 def web_search(query: str, num_results: int = 5) -> str:
-    params = {"engine":"google","q":query,"num":num_results,"api_key":SERPAPI_KEY}
+    """
+    Run a Google search via SerpAPI and return the top snippets.
+
+    Args:
+        query: the search query.
+        num_results: how many results to fetch (default 5).
+
+    Returns:
+        A newline-separated list of snippet strings.
+    """
+    params = {
+        "engine": "google",
+        "q": query,
+        "num": num_results,
+        "api_key": SERPAPI_KEY,
+    }
     res = GoogleSearch(params).get_dict().get("organic_results", [])
     return "\n".join(f"- {r.get('snippet','')}" for r in res)
 
+
 @tool
 def wiki_search(query: str) -> str:
+    """
+    Search Wikipedia for up to 2 pages matching `query`.
+
+    Args:
+        query: the topic to look up on Wikipedia.
+
+    Returns:
+        The combined page contents of the top-2 Wikipedia results.
+    """
     pages = WikipediaLoader(query=query, load_max_docs=2).load()
     return "\n\n---\n\n".join(d.page_content for d in pages)
 
+
 @tool
 def arxiv_search(query: str) -> str:
+    """
+    Search ArXiv for up to 3 papers matching `query` and return abstracts.
+
+    Args:
+        query: the search query for ArXiv.
+
+    Returns:
+        The first 1000 characters of each of the top-3 ArXiv abstracts.
+    """
    papers = ArxivLoader(query=query, load_max_docs=3).load()
     return "\n\n---\n\n".join(d.page_content[:1000] for d in papers)
 
-
 # ────────────────
 # 5️⃣ Define your State schema
 # ────────────────
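
Context for the change: in current LangChain releases the canonical home of the @tool decorator is langchain_core.tools (langchain.tools merely re-exported it), and the decorator derives each tool's LLM-facing description from the function's docstring; in many langchain-core versions, decorating a function that has neither a docstring nor an explicit description raises a ValueError at definition time, which is presumably why this commit adds docstrings alongside the import fix. Below is a minimal sketch, not part of the commit, reusing the calculator tool from this diff to show the metadata the decorator produces; it assumes only that langchain-core is installed.

# Minimal sketch, assuming langchain-core is installed.
from langchain_core.tools import tool

@tool
def calculator(expr: str) -> str:
    """
    Evaluate the given Python expression and return its result as a string.
    Returns "Error" if evaluation fails.
    """
    try:
        return str(eval(expr))
    except Exception:
        return "Error"

# The decorator wraps the function in a tool object; the docstring becomes
# the description the agent's LLM sees when deciding which tool to call.
print(calculator.name)         # calculator
print(calculator.description)  # text taken from the docstring above
print(calculator.args)         # schema inferred from the signature, e.g. {'expr': {...}}

# Tools are invoked with a dict keyed by parameter name.
print(calculator.invoke({"expr": "2 + 2"}))  # prints 4

The printed metadata (name, description, args) is what an agent framework such as LangGraph binds to the model, so the added docstrings matter beyond readability: they are the only hint the LLM gets about when to use each tool.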