# agent.py
import os
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from serpapi import GoogleSearch
# 1️⃣ Switch Graph → StateGraph
from langgraph.graph import StateGraph
#from langchain_core.language_models.llms import LLM
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
INDEX.add(EMBS)
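# Sanity check (hypothetical usage, not part of the pipeline): a probe query
# should come back with k neighbour ids from the freshly built index, e.g.
#   probe = EMBEDDER.encode(["test query"]).astype("float32")
#   _, ids = INDEX.search(probe, 3)   # ids.shape == (1, 3)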
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read().strip()
# ────────────────
# 4️⃣ Define your tools (with docstrings)
# ────────────────
@tool
def calculator(expr: str) -> str:
    """
    Evaluate the given Python expression and return its result as a string.
    Returns "Error" if evaluation fails.
    """
    try:
        # eval on model-supplied text is risky; strip builtins to limit the blast radius
        return str(eval(expr, {"__builtins__": {}}, {}))
    except Exception:
        return "Error"
@tool
def retrieve_docs(query: str, k: int = 3) -> str:
    """
    Perform vector similarity search over the FAISS index.
    Args:
        query: the user's query string to embed and search for.
        k: the number of nearest documents to return (default 3).
    Returns:
        The top-k document contents concatenated into one string.
    """
    q_emb = EMBEDDER.encode([query]).astype("float32")
    D, I = INDEX.search(q_emb, k)
    return "\n\n---\n\n".join(DOCS[i] for i in I[0])
# SerpAPI key is read once at import time; web_search yields no snippets without it
SERPAPI_KEY = os.getenv("SERPAPI_KEY")
@tool
def web_search(query: str, num_results: int = 5) -> str:
    """
    Run a Google search via SerpAPI and return the top snippets.
    Args:
        query: the search query.
        num_results: how many results to fetch (default 5).
    Returns:
        A newline-separated list of snippet strings.
    """
    params = {
        "engine": "google",
        "q": query,
        "num": num_results,
        "api_key": SERPAPI_KEY,
    }
    res = GoogleSearch(params).get_dict().get("organic_results", [])
    return "\n".join(f"- {r.get('snippet', '')}" for r in res)
@tool
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for up to 2 pages matching `query`.
    Args:
        query: the topic to look up on Wikipedia.
    Returns:
        The combined page contents of the top-2 Wikipedia results.
    """
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(d.page_content for d in pages)
@tool
def arxiv_search(query: str) -> str:
    """
    Search ArXiv for up to 3 papers matching `query` and return abstracts.
    Args:
        query: the search query for ArXiv.
    Returns:
        The first 1000 characters of each of the top-3 ArXiv abstracts.
    """
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(d.page_content[:1000] for d in papers)
# ────────────────
# 5️⃣ Define your State schema
# ────────────────
from typing import TypedDict, List
from langchain_core.messages import BaseMessage
class AgentState(TypedDict):
    # We'll carry a list of messages as our "chat history"
    messages: List[BaseMessage]
    # GAIA task id, stashed by human_node and read back by answer_node
    task_id: str
# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # ChatGroq also picks this up from the environment
def build_graph(provider: str = "groq") -> StateGraph:
    # provider is accepted for future backends, but only Groq is wired up here
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    # 6.1) Node: init → seed system prompt
    def init_node(_: AgentState) -> AgentState:
        return {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT)
            ]
        }
    # 6.2) Node: human → stash the GAIA task ID, then append the question
    def human_node(state: AgentState, task_id: str, question: str) -> AgentState:
        # keep the GAIA task id so we can submit it later
        state["task_id"] = task_id
        state["messages"].append(HumanMessage(content=question))
        return state
    # 6.3) Node: assistant → call LLM on current messages
    def assistant_node(state: AgentState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        return {"messages": state["messages"] + [ai_msg]}
    # 6.4) Optional: tool nodes (they'll read the last HumanMessage)
    def make_tool_node(fn):
        def tool_node(state: AgentState) -> dict:
            # fetch the latest human query
            last_query = state["messages"][-1].content
            # fn is a @tool-wrapped object, so call it via .invoke
            result = fn.invoke(last_query)
            # append the tool's output as if from system/Human
            state["messages"].append(HumanMessage(content=result))
            return {"messages": state["messages"]}
        return tool_node
    # 6.5) Node: answer → pull out the last assistant reply & format the submission dict
    def answer_node(state: AgentState) -> dict[str, str]:
        # the GAIA runner will do `.items()` on whatever you return here
        tid = state["task_id"]
        # grab the last message (could be a BaseMessage or a raw str)
        last = state["messages"][-1]
        text = getattr(last, "content", None) or str(last)
        return {tid: text}
    # Instantiate nodes for each tool
    calc_node = make_tool_node(calculator)
    retrieve_node = make_tool_node(retrieve_docs)
    web_node = make_tool_node(web_search)
    wiki_node = make_tool_node(wiki_search)
    arxiv_node = make_tool_node(arxiv_search)
    # 6.6) Build the graph
    g = StateGraph(AgentState)
    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("calc", calc_node)
    g.add_node("retrieve", retrieve_node)
    g.add_node("web", web_node)
    g.add_node("wiki", wiki_node)
    g.add_node("arxiv", arxiv_node)
    # Wire up edges
    from langgraph.graph import END
    g.set_entry_point("init")
    # init → human (placeholder: we'll inject the actual question at runtime)
    g.add_edge("init", "human")
    # human → assistant
    g.add_edge("human", "assistant")
    # assistant → tool nodes. These are plain edges, so every tool fires on
    # each pass; a conditional router keyed on tool calls is sketched below.
    g.add_edge("assistant", "calc")
    g.add_edge("assistant", "retrieve")
    g.add_edge("assistant", "web")
    g.add_edge("assistant", "wiki")
    g.add_edge("assistant", "arxiv")
    # each tool returns back into assistant for follow-up
    g.add_edge("calc", "assistant")
    g.add_edge("retrieve", "assistant")
    g.add_edge("web", "assistant")
    g.add_edge("wiki", "assistant")
    g.add_edge("arxiv", "assistant")
    # register & wire your new answer node
    g.add_node("answer", answer_node)
    # send assistant → answer → END
    g.add_edge("assistant", "answer")
    g.add_edge("answer", END)
    return g.compile()
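# Minimal local smoke test (hypothetical usage; the GAIA runner normally
# drives this module). It exercises the module-level tools and confirms the
# graph compiles; invoking the full graph also requires binding a concrete
# task id and question into human_node, per the placeholder note above.
if __name__ == "__main__":
    print(calculator.invoke("2 + 2"))             # expect "4"
    print(retrieve_docs.invoke("sample query")[:200])
    agent = build_graph()
    print(type(agent).__name__)                   # compiled graph object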