# agent.py
import os
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from serpapi import GoogleSearch
# 1️⃣ Switch Graph → StateGraph
from langgraph.graph import StateGraph
#from langchain_core.language_models.llms import LLM
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
INDEX.add(EMBS)
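# Quick sanity check of the index (commented out; the query text is made up):
# q = EMBEDDER.encode(["example query"]).astype("float32")
# dists, ids = INDEX.search(q, 1)  # ids[0][0] is the row of the nearest doc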
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read().strip()
# ────────────────
# 4️⃣ Define your tools (with docstrings)
# ────────────────
def calculator(expr: str) -> str:
    """
    Evaluate the given Python expression and return its result as a string.
    Returns "Error" if evaluation fails.
    """
    try:
        # NOTE: bare eval is unsafe on untrusted input; fine for a trusted demo.
        return str(eval(expr))
    except Exception:
        return "Error"
def retrieve_docs(query: str, k: int = 3) -> str:
    """
    Perform vector similarity search over the FAISS index.
    Args:
        query: the user's query string to embed and search for.
        k: the number of nearest documents to return (default 3).
    Returns:
        The top-k document contents concatenated into one string.
    """
    q_emb = EMBEDDER.encode([query]).astype("float32")
    _dists, ids = INDEX.search(q_emb, k)
    return "\n\n---\n\n".join(DOCS[i] for i in ids[0])
SERPAPI_KEY = os.getenv("SERPAPI_KEY")
def web_search(query: str, num_results: int = 5) -> str:
    """
    Run a Google search via SerpAPI and return the top snippets.
    Args:
        query: the search query.
        num_results: how many results to fetch (default 5).
    Returns:
        A newline-separated list of snippet strings.
    """
    params = {
        "engine": "google",
        "q": query,
        "num": num_results,
        "api_key": SERPAPI_KEY,
    }
    res = GoogleSearch(params).get_dict().get("organic_results", [])
    return "\n".join(f"- {r.get('snippet', '')}" for r in res)
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for up to 2 pages matching `query`.
    Args:
        query: the topic to look up on Wikipedia.
    Returns:
        The combined page contents of the top-2 Wikipedia results.
    """
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(d.page_content for d in pages)
def arxiv_search(query: str) -> str:
    """
    Search ArXiv for up to 3 papers matching `query`.
    Args:
        query: the search query for ArXiv.
    Returns:
        The first 1000 characters of each of the top-3 papers' content.
    """
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(d.page_content[:1000] for d in papers)
# ────────────────
# 5️⃣ Define your State schema
# ────────────────
from typing import TypedDict, List
from langchain_core.messages import BaseMessage
class AgentState(TypedDict):
    # We'll carry a list of messages as our "chat history"
    messages: List[BaseMessage]
    # GAIA bookkeeping: the task id & question, seeded by the caller at invoke time
    task_id: str
    question: str
    # the {task_id: answer} dict that answer_node produces
    submission: dict
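# Example initial state the caller would seed (values are placeholders):
#   {"task_id": "demo-task-001", "question": "What is 2 + 2?", "messages": []}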
# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
def build_graph(provider: str = "groq"):
    if provider == "groq":
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    else:
        # Fallback provider: a hedged sketch. The repo_id is an assumption;
        # swap in whichever hosted model you have access to.
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct")
        )
    # 6.1) Node: init → seed system prompt
    def init_node(_: AgentState) -> AgentState:
        return {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT)
            ]
        }
    # 6.2) Node: human → append the GAIA question from state
    def human_node(state: AgentState) -> AgentState:
        # task_id and question are seeded into the initial state by the
        # caller, so this node only needs to append the question.
        state["messages"].append(HumanMessage(content=state["question"]))
        return state
    # 6.3) Node: assistant → call LLM on current messages
    def assistant_node(state: AgentState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        return {"messages": state["messages"] + [ai_msg]}
    # 6.4) Optional: tool nodes (they'll read the last HumanMessage)
    def make_tool_node(fn):
        def tool_node(state: AgentState) -> dict:
            # fetch the latest human query
            last_query = state["messages"][-1].content
            result = fn(last_query)
            # append the tool's output as if from the user
            state["messages"].append(HumanMessage(content=result))
            return {"messages": state["messages"]}
        return tool_node
    # 6.5) Node: answer → pull out the last assistant reply & build the submission dict
    def answer_node(state: AgentState) -> dict:
        tid = state["task_id"]
        # grab the last message (could be a BaseMessage or a raw str)
        last = state["messages"][-1]
        text = getattr(last, "content", None) or str(last)
        # store the {task_id: answer} mapping under a declared state key;
        # the GAIA runner can call .items() on state["submission"]
        return {"submission": {tid: text}}
    # Instantiate nodes for each tool
    calc_node = make_tool_node(calculator)
    retrieve_node = make_tool_node(retrieve_docs)
    web_node = make_tool_node(web_search)
    wiki_node = make_tool_node(wiki_search)
    arxiv_node = make_tool_node(arxiv_search)
    # 6.6) Build the graph
    g = StateGraph(AgentState)
    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("calc", calc_node)
    g.add_node("retrieve", retrieve_node)
    g.add_node("web", web_node)
    g.add_node("wiki", wiki_node)
    g.add_node("arxiv", arxiv_node)
    # Wire up edges
    from langgraph.graph import END
    g.set_entry_point("init")
    # init → human (the caller seeds task_id & question in the initial state)
    g.add_edge("init", "human")
    # human → assistant
    g.add_edge("human", "assistant")
    # assistant → tool nodes (conditional on tool calls) or straight to answer.
    # A plain add_edge fan-out here would run every tool on every turn and
    # loop forever, so route explicitly. The "[toolname]" keyword convention
    # below is a placeholder; a production agent would bind the tools to the
    # LLM and inspect ai_msg.tool_calls instead.
    def route_from_assistant(state: AgentState) -> str:
        last = state["messages"][-1]
        text = (getattr(last, "content", "") or "").lower()
        for name in ("calc", "retrieve", "web", "wiki", "arxiv"):
            if f"[{name}]" in text:
                return name
        return "answer"
    g.add_conditional_edges(
        "assistant",
        route_from_assistant,
        {"calc": "calc", "retrieve": "retrieve", "web": "web",
         "wiki": "wiki", "arxiv": "arxiv", "answer": "answer"},
    )
    # each tool returns back into assistant for follow-up
    g.add_edge("calc", "assistant")
    g.add_edge("retrieve", "assistant")
    g.add_edge("web", "assistant")
    g.add_edge("wiki", "assistant")
    g.add_edge("arxiv", "assistant")
# register & wire your new answer node | |
g.add_node("answer", answer_node) | |
# send assistant β answer β END | |
g.add_edge("assistant", "answer") | |
g.add_edge("answer", END) | |
return g.compile() | |
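# ────────────────
# Example usage: a minimal sketch, not the GAIA runner itself.
# The task_id and question below are made-up placeholders.
# ────────────────
if __name__ == "__main__":
    graph = build_graph()
    final_state = graph.invoke({
        "task_id": "demo-task-001",
        "question": "What is 2 + 2?",
        "messages": [],
    })
    for tid, answer in final_state["submission"].items():
        print(tid, "->", answer)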