Files changed (1)
  1. agent.py +123 -134
agent.py CHANGED
@@ -1,171 +1,160 @@
  # agent.py
  import os
- # from supabase import create_client
- from sentence_transformers import SentenceTransformer
- from serpapi import GoogleSearch
  import pandas as pd
  import faiss
- from langgraph.graph import Graph
  from langchain_core.language_models.llms import LLM
- from langchain_core.tools import tool
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader

- # ─── 1) Load & embed all documents at startup ───
- # 1a) Read CSV of docs
  df = pd.read_csv("documents.csv")
  DOCS = df["content"].tolist()

- # 1b) Create an embedding model
  EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-
- # 1c) Compute embeddings (float32) and build FAISS index
  EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
  INDEX = faiss.IndexFlatL2(EMBS.shape[1])
  INDEX.add(EMBS)


- # ---- Supabase setup ----
- SUPABASE_URL = os.getenv("SUPABASE_URL")
- SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
- EMBED_MODEL_ID = os.getenv("HF_EMBEDDING_MODEL")
-
-
- # sb_client = create_client(SUPABASE_URL, SUPABASE_KEY)
- # embedder = SentenceTransformer(EMBED_MODEL_ID)
-
- # 1) Define tools
-
  @tool
  def calculator(expr: str) -> str:
-     """Simple math via Python eval"""
      try:
          return str(eval(expr))
-     except Exception:
          return "Error"
- # e.g. search, vector_retrieval, etc.
- # @tool
- # def web_search(query: str) -> str:
- #     ...
- # @tool
- # def retrieve_docs(query: str, k: int = 3) -> str:
- #     """
- #     Fetch top-k docs from Supabase vector store.
- #     Returns the concatenated text.
- #     """
- #     # --- embed the query
- #     q_emb = embedder.encode(query).tolist()
-
- #     # --- query the embedding table
- #     response = (
- #         sb_client
- #         .rpc("match_documents", {"query_embedding": q_emb, "match_count": k})
- #         .execute()
- #     )
- #     rows = response.data
-
- #     # ---- concatenate the content field
- #     docs = [row["content"] for row in rows]
- #     return "\n\n---\n\n".join(docs)

  @tool
  def retrieve_docs(query: str, k: int = 3) -> str:
-     """
-     k-NN search over our in-memory FAISS index.
-     Returns the top-k documents concatenated.
-     """
-     # 1) Embed the query
      q_emb = EMBEDDER.encode([query]).astype("float32")
-     # 2) Search FAISS
      D, I = INDEX.search(q_emb, k)
-     # 3) Gather and return the texts
-     hits = [DOCS[i] for i in I[0]]
-     return "\n\n---\n\n".join(hits)

-
  SERPAPI_KEY = os.getenv("SERPAPI_KEY")
- # ---- web_search tool
  @tool
  def web_search(query: str, num_results: int = 5) -> str:
-     """Return top-5 snippets from Google search via SerpAPI."""
-     params = {
-         "engine": "google",
-         "q": query,
-         "num": num_results,
-         "api_key": SERPAPI_KEY,
-     }
-     search = GoogleSearch(params)
-     results = search.get_dict().get("organic_results", [])
-     snippets = [r.get("snippet", "") for r in results]
-     return "\n".join(f"- {s}" for s in snippets)

  @tool
  def wiki_search(query: str) -> str:
-     """
-     Search Wikipedia for up to 2 pages matching 'query',
-     and return their contents.
-     """
-     # load up to 2 pages
      pages = WikipediaLoader(query=query, load_max_docs=2).load()
-     # format as plain text
-     return "\n\n---\n\n".join(doc.page_content for doc in pages)

  @tool
- def arxiv_search(query: str) -> str:
-     """
-     Search ArXiv for up to 3 abstracts matching 'query',
-     and return their first 1000 characters.
-     """
      papers = ArxivLoader(query=query, load_max_docs=3).load()
-     return "\n\n---\n\n".join(doc.page_content[:1000] for doc in papers)
-
-
- # read the system prompt
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     SYSTEM_PROMPT = f.read()
-
- # 2) Build your graph
- def build_graph(provider: str = "huggingface") -> Graph:
-     # 2a) Instantiate your LLM endpoint
-     api_token = os.getenv("HF_TOKEN")
-     if not api_token:
-         raise ValueError("HF_TOKEN not found: please add it under Settings → Secrets and variables in your Space")
-     llm = LLM(provider=provider, token=api_token, model="meta-llama/Llama-2-7b-chat-hf")
-
-     # 2b) Attach tools
-     tools = [
-         calculator,
-         retrieve_docs,
-         web_search,
-         wiki_search,
-         arxiv_search,  # add more tools here
-     ]
-     llm_with_tools = llm.bind_tools(tools)
-
-     # 2c) Compose your graph
-     graph = Graph()
-     # sys node: prepend system prompt
-     def _prepend_system(query: str) -> str:
-         return SYSTEM_PROMPT + "\n\n" + query
-     graph.add_node("sys", _prepend_system)
-     # "ask" node: the LLM itself
-     graph.add_node("ask", llm_with_tools)  # prompt node
-     # tool nodes
-     graph.add_node("calc", calculator)
-     graph.add_node("retrieve", retrieve_docs)
-     graph.add_node("web_search", web_search)
-     graph.add_node("wiki", wiki_search)
-     graph.add_node("arxiv", arxiv_search)
-     # allow the LLM to call any tool:
-     graph.add_edge("ask", "calc")  # allow ask -> calc
-     graph.add_edge("ask", "retrieve")
-     graph.add_edge("ask", "web_search")
-     graph.add_edge("ask", "wiki")
-     graph.add_edge("ask", "arxiv")
-
-     # wire up the start: sys -> ask
-     graph.add_edge(Graph.START, "sys")
-     graph.add_edge("sys", "ask")
-     graph.set_start("ask")
-
-     return graph
-
 
  # agent.py
+
  import os
  import pandas as pd
  import faiss
+
+ from sentence_transformers import SentenceTransformer
+ from serpapi import GoogleSearch
+
+ # 1️⃣ Switch Graph → StateGraph
+ from langgraph.graph import StateGraph
  from langchain_core.language_models.llms import LLM
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader

+ # ────────────────
+ # 2️⃣ Load & index your static FAISS docs
+ # ────────────────
  df = pd.read_csv("documents.csv")
  DOCS = df["content"].tolist()

  EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
  EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
  INDEX = faiss.IndexFlatL2(EMBS.shape[1])
  INDEX.add(EMBS)

+ # ────────────────
+ # 3️⃣ Read your system prompt
+ # ────────────────
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     SYSTEM_PROMPT = f.read().strip()

+ # ────────────────
+ # 4️⃣ Define your tools (unchanged semantics)
+ # ────────────────
  @tool
  def calculator(expr: str) -> str:
      try:
          return str(eval(expr))
+     except:
          return "Error"

  @tool
  def retrieve_docs(query: str, k: int = 3) -> str:
      q_emb = EMBEDDER.encode([query]).astype("float32")
      D, I = INDEX.search(q_emb, k)
+     return "\n\n---\n\n".join(DOCS[i] for i in I[0])

  SERPAPI_KEY = os.getenv("SERPAPI_KEY")
  @tool
  def web_search(query: str, num_results: int = 5) -> str:
+     params = {"engine": "google", "q": query, "num": num_results, "api_key": SERPAPI_KEY}
+     res = GoogleSearch(params).get_dict().get("organic_results", [])
+     return "\n".join(f"- {r.get('snippet', '')}" for r in res)

  @tool
  def wiki_search(query: str) -> str:
      pages = WikipediaLoader(query=query, load_max_docs=2).load()
+     return "\n\n---\n\n".join(d.page_content for d in pages)

  @tool
+ def arxiv_search(query: str) -> str:
      papers = ArxivLoader(query=query, load_max_docs=3).load()
+     return "\n\n---\n\n".join(d.page_content[:1000] for d in papers)
+
+
+ # ────────────────
+ # 5️⃣ Define your State schema
+ # ────────────────
+ from typing import TypedDict, List
+ from langchain_core.messages import BaseMessage
+
+ class AgentState(TypedDict):
+     # We'll carry a list of messages as our "chat history"
+     messages: List[BaseMessage]
+
+
+ # ────────────────
+ # 6️⃣ Build the StateGraph
+ # ────────────────
+ def build_graph(provider: str = "huggingface") -> StateGraph:
+     # Instantiate LLM
+     hf_token = os.getenv("HF_TOKEN")
+     if not hf_token:
+         raise ValueError("HF_TOKEN missing in env")
+     llm = LLM(provider=provider, token=hf_token, model="meta-llama/Llama-2-7b-chat-hf")
+
+     # 6.1) Node: init → seed system prompt
+     def init_node(_: AgentState) -> AgentState:
+         return {
+             "messages": [
+                 SystemMessage(content=SYSTEM_PROMPT)
+             ]
+         }
+
+     # 6.2) Node: human → append user question
+     def human_node(state: AgentState, question: str) -> AgentState:
+         state["messages"].append(HumanMessage(content=question))
+         return state
+
+     # 6.3) Node: assistant → call LLM on current messages
+     def assistant_node(state: AgentState) -> dict:
+         ai_msg = llm.invoke(state["messages"])
+         return {"messages": state["messages"] + [ai_msg]}
+
+     # 6.4) Optional: tool nodes (they'll read last HumanMessage)
+     def make_tool_node(fn):
+         def tool_node(state: AgentState) -> dict:
+             # fetch the latest human query
+             last_query = state["messages"][-1].content
+             result = fn(last_query)
+             # append the tool's output as if from system/Human
+             state["messages"].append(HumanMessage(content=result))
+             return {"messages": state["messages"]}
+         return tool_node
+
+     # Instantiate nodes for each tool
+     calc_node = make_tool_node(calculator)
+     retrieve_node = make_tool_node(retrieve_docs)
+     web_node = make_tool_node(web_search)
+     wiki_node = make_tool_node(wiki_search)
+     arxiv_node = make_tool_node(arxiv_search)
+
+     # 6.5) Build the graph
+     g = StateGraph(AgentState)
+
+     # Register nodes
+     g.add_node("init", init_node)
+     g.add_node("human", human_node)
+     g.add_node("assistant", assistant_node)
+     g.add_node("calc", calc_node)
+     g.add_node("retrieve", retrieve_node)
+     g.add_node("web", web_node)
+     g.add_node("wiki", wiki_node)
+     g.add_node("arxiv", arxiv_node)
+
+     # Wire up edges
+     from langgraph.graph import END
+     g.set_entry_point("init")
+     # init → human (placeholder: we'll inject the actual question at runtime)
+     g.add_edge("init", "human")
+     # human → assistant
+     g.add_edge("human", "assistant")
+     # assistant → tool nodes (conditional on tool calls)
+     g.add_edge("assistant", "calc")
+     g.add_edge("assistant", "retrieve")
+     g.add_edge("assistant", "web")
+     g.add_edge("assistant", "wiki")
+     g.add_edge("assistant", "arxiv")
+     # each tool returns back into assistant for follow-up
+     g.add_edge("calc", "assistant")
+     g.add_edge("retrieve", "assistant")
+     g.add_edge("web", "assistant")
+     g.add_edge("wiki", "assistant")
+     g.add_edge("arxiv", "assistant")
+     # and finally assistant → END when done
+     g.add_edge("assistant", END)
+
+     return g.compile()
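
Since build_graph() now returns a compiled StateGraph, callers invoke it with an initial state dict (a messages list) rather than a raw prompt string. The snippet below is a minimal, self-contained sketch of that calling pattern, not part of the diff: it assumes only that langgraph is installed, and the DemoState schema plus the echo-style assistant_node are illustrative stand-ins for the real AgentState and LLM call.

# Hypothetical driver sketch (not part of agent.py): invoking a compiled StateGraph.
from typing import List, TypedDict

from langgraph.graph import END, StateGraph


class DemoState(TypedDict):
    # Simplified stand-in for AgentState; plain strings instead of BaseMessage.
    messages: List[str]


def assistant_node(state: DemoState) -> dict:
    # Echo the last message; the real node would call llm.invoke(state["messages"]).
    return {"messages": state["messages"] + [f"echo: {state['messages'][-1]}"]}


g = StateGraph(DemoState)
g.add_node("assistant", assistant_node)
g.set_entry_point("assistant")
g.add_edge("assistant", END)
app = g.compile()

# The question travels inside the initial state, not as a node argument.
result = app.invoke({"messages": ["What is 2 + 2?"]})
print(result["messages"][-1])

One detail worth double-checking against the diff: LangGraph calls each node with the state alone, so a signature like human_node(state, question) will not receive a separate question argument at runtime; the question is usually carried in the initial state passed to invoke(), as sketched above.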