AIxGAIA / agent.py
AC-Angelo93's picture
Update agent.py
72fbc8e verified
# agent.py
import os
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from serpapi import GoogleSearch
# 1️⃣ Switch Graph → StateGraph
from langgraph.graph import StateGraph
#from langchain_core.language_models.llms import LLM
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
# Read the corpus from disk; every value of the "content" column becomes one
# entry of DOCS, so a row position in the index maps straight back to DOCS.
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()

# Embed all documents once at import time and keep the vectors as float32.
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
EMBS = (
    EMBEDDER
    .encode(DOCS, show_progress_bar=True)
    .astype("float32")
)

# Flat L2 index whose dimensionality is the embedding width (axis 1 of EMBS).
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
INDEX.add(EMBS)
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
# Load the system prompt once at import time; leading/trailing whitespace is
# trimmed after the file has been closed.
with open("system_prompt.txt", mode="r", encoding="utf-8") as f:
    _prompt_text = f.read()
SYSTEM_PROMPT = _prompt_text.strip()
# ────────────────
# 4️⃣ Define your tools (with docstrings)
# ────────────────
@tool
def calculator(expr: str) -> str:
    """
    Evaluate the given Python expression and return its result as a string.
    Returns "Error" if evaluation fails.
    """
    # The docstring above is kept verbatim: @tool exposes it as the tool
    # description.
    #
    # SECURITY NOTE(review): eval() runs arbitrary Python. If `expr` can come
    # from untrusted input (e.g. model output driven by user text), this must
    # be replaced by a restricted arithmetic evaluator. Flagged, not changed.
    try:
        value = eval(expr)
        outcome = str(value)
    except Exception:
        # Any failure (syntax error, runtime error, failing __str__) is
        # reported to the caller as the literal string "Error".
        outcome = "Error"
    return outcome
@tool
def retrieve_docs(query: str, k: int = 3) -> str:
    """
    Perform vector similarity search over the FAISS index.
    Args:
        query: the user's query string to embed and search for.
        k: the number of nearest documents to return (default 3).
    Returns:
        The top-k document contents concatenated into one string.
    """
    q_emb = EMBEDDER.encode([query]).astype("float32")
    # Only the neighbour labels are needed; the distances are discarded.
    _, neighbours = INDEX.search(q_emb, k)
    # When the index holds fewer than k vectors FAISS pads the result row with
    # label -1. Indexing DOCS with -1 would silently return the *last*
    # document, so padded slots are dropped instead.
    hits = [DOCS[i] for i in neighbours[0] if i >= 0]
    return "\n\n---\n\n".join(hits)
# SerpAPI credential, read from the environment once at import time.
SERPAPI_KEY = os.getenv("SERPAPI_KEY")


@tool
def web_search(query: str, num_results: int = 5) -> str:
    """
    Run a Google search via SerpAPI and return the top snippets.
    Args:
    query: the search query.
    num_results: how many results to fetch (default 5).
    Returns:
    A newline-separated list of snippet strings.
    """
    # The docstring above is kept verbatim: @tool exposes it as the tool
    # description.
    search = GoogleSearch(
        {
            "engine": "google",
            "q": query,
            "num": num_results,
            "api_key": SERPAPI_KEY,
        }
    )
    payload = search.get_dict()
    # A response without "organic_results" yields an empty string; a result
    # without a snippet still contributes a bare "- " line.
    organic_hits = payload.get("organic_results", [])
    bullet_lines = [f"- {hit.get('snippet','')}" for hit in organic_hits]
    return "\n".join(bullet_lines)
@tool
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for up to 2 pages matching `query`.
    Args:
    query: the topic to look up on Wikipedia.
    Returns:
    The combined page contents of the top-2 Wikipedia results.
    """
    # The docstring above is kept verbatim: @tool exposes it as the tool
    # description.
    loader = WikipediaLoader(query=query, load_max_docs=2)
    page_texts = [page.page_content for page in loader.load()]
    # Pages are separated by a horizontal-rule style divider.
    return "\n\n---\n\n".join(page_texts)
@tool
def arxiv_search(query: str) -> str:
    """
    Search ArXiv for up to 3 papers matching `query` and return abstracts.
    Args:
    query: the search query for ArXiv.
    Returns:
    The first 1000 characters of each of the top-3 ArXiv abstracts.
    """
    # The docstring above is kept verbatim: @tool exposes it as the tool
    # description.
    # NOTE(review): what is returned is the first 1000 characters of each
    # document's `page_content`; whether that is the abstract or the paper
    # body depends on ArxivLoader - confirm against its documentation.
    loader = ArxivLoader(query=query, load_max_docs=3)
    excerpts = [paper.page_content[:1000] for paper in loader.load()]
    return "\n\n---\n\n".join(excerpts)
# ────────────────
# 5️⃣ Define your State schema
# ────────────────
from typing import TypedDict, List
from langchain_core.messages import BaseMessage
class _RequiredAgentState(TypedDict):
    """Keys that every agent state must carry."""

    # We'll carry a list of messages as our "chat history"
    messages: List[BaseMessage]


class AgentState(_RequiredAgentState, total=False):
    """State schema shared by the graph nodes.

    `messages` is required. `task_id` is optional so that states built as
    ``{"messages": [...]}`` remain valid, but it has to be declared here:
    the graph nodes in this file store and read ``state["task_id"]``, and
    LangGraph only keeps state keys that appear in the schema.
    """

    # GAIA task identifier of the question being answered.
    task_id: str
# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
# Kept as a module-level constant for compatibility. build_graph() does not
# pass it on: ChatGroq picks GROQ_API_KEY up from the environment itself.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")


def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    Flow::

        init -> human -> assistant -> [calc|retrieve|web|wiki|arxiv]* -> answer -> END

    The assistant is bound to the five tools. After every assistant turn the
    router inspects the tool calls the model requested: each requested tool
    runs in its own node, and once every call has been answered control
    returns to the assistant. A reply without tool calls goes to ``answer``.

    Supplying the question (either form works)::

        graph.invoke({"messages": [HumanMessage(content=q)], "task_id": tid})
        graph.invoke({"messages": []},
                     config={"configurable": {"question": q, "task_id": tid}})

    The final state contains ``messages`` (full history), ``task_id`` and
    ``submission`` (``{task_id: answer_text}``).

    Args:
        provider: accepted for backward compatibility. Only the Groq model is
            wired up, so the value is currently ignored.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.

    Raises:
        ValueError: at run time, from the ``human`` node, when no question was
            supplied through the state or the config.
    """
    # Local imports: these names are not part of the file-level import block.
    from langchain_core.messages import AIMessage, ToolMessage
    from langchain_core.runnables import RunnableConfig
    from langgraph.graph import END

    class GraphState(AgentState, total=False):
        # LangGraph drops node output keys that are missing from the schema,
        # so the bookkeeping keys are declared here rather than relying on
        # AgentState. The result key is called "submission" because a state
        # key must not share a name with a node ("answer").
        task_id: str
        submission: dict

    # Graph node name -> tool object, and tool name -> graph node name.
    tool_by_node = {
        "calc": calculator,
        "retrieve": retrieve_docs,
        "web": web_search,
        "wiki": wiki_search,
        "arxiv": arxiv_search,
    }
    node_by_tool = {tool_obj.name: node for node, tool_obj in tool_by_node.items()}

    # Without bind_tools the model never learns the tools exist and can never
    # emit a tool call.
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0).bind_tools(
        list(tool_by_node.values())
    )

    def pending_tool_calls(messages):
        """Return the tool calls of the latest AI turn that lack a ToolMessage."""
        answered = set()
        for msg in reversed(messages):
            if isinstance(msg, ToolMessage):
                answered.add(msg.tool_call_id)
            elif isinstance(msg, AIMessage):
                return [c for c in msg.tool_calls if c["id"] not in answered]
            else:
                break
        return []

    # 6.1) Node: init → seed the system prompt in front of whatever came in
    def init_node(state: GraphState) -> dict:
        incoming = list(state.get("messages") or [])
        return {"messages": [SystemMessage(content=SYSTEM_PROMPT), *incoming]}

    # 6.2) Node: human → stash the GAIA task id, then append the question
    def human_node(state: GraphState, config: RunnableConfig) -> dict:
        # LangGraph calls nodes with (state) or (state, config) only, so
        # run-time values arrive through config["configurable"] or the state.
        runtime = config.get("configurable", {})
        history = list(state["messages"])
        question = runtime.get("question")
        if question is not None:
            history.append(HumanMessage(content=question))
        if not any(isinstance(m, HumanMessage) for m in history):
            raise ValueError(
                "no question supplied: pass a HumanMessage in 'messages' or "
                "'question' in config['configurable']"
            )
        # keep the GAIA task id so the answer node can key the submission
        task_id = runtime.get("task_id", state.get("task_id", ""))
        return {"messages": history, "task_id": task_id}

    # 6.3) Node: assistant → call the LLM on the current messages
    def assistant_node(state: GraphState) -> dict:
        history = list(state["messages"])
        # Calls still pending here name tools that do not exist (the router
        # sends every known tool to its node first). Answer them with an
        # error so the conversation stays well-formed for the provider.
        for call in pending_tool_calls(history):
            history.append(
                ToolMessage(
                    content=f"Error: unknown tool {call['name']!r}",
                    tool_call_id=call["id"],
                )
            )
        history.append(llm.invoke(history))
        return {"messages": history}

    # 6.4) Tool nodes: run the pending calls addressed to one tool
    def make_tool_node(fn):
        def tool_node(state: GraphState) -> dict:
            history = list(state["messages"])
            for call in pending_tool_calls(history):
                if call["name"] != fn.name:
                    continue
                try:
                    result = fn.invoke(call["args"])
                except Exception as exc:
                    # Tool boundary: report the failure to the model instead
                    # of aborting the whole run, so it can retry or give up.
                    result = f"Error: {exc}"
                history.append(
                    ToolMessage(content=str(result), tool_call_id=call["id"])
                )
            return {"messages": history}

        return tool_node

    # 6.5) Node: answer → pull out the last reply & format the submission dict
    def answer_node(state: GraphState) -> dict:
        tid = state.get("task_id", "")
        # grab the last message (could be a BaseMessage or a raw str)
        last = state["messages"][-1]
        text = getattr(last, "content", None) or str(last)
        return {"submission": {tid: text}}

    # Router shared by the assistant and every tool node.
    def next_step(state: GraphState) -> str:
        messages = state["messages"]
        pending = pending_tool_calls(messages)
        for call in pending:
            node = node_by_tool.get(call["name"])
            if node is not None:
                return node
        # Unknown-tool calls, or a finished round of tool results: let the
        # assistant continue. Otherwise the model gave its final reply.
        if pending or isinstance(messages[-1], ToolMessage):
            return "assistant"
        return "answer"

    # 6.6) Build the graph
    g = StateGraph(GraphState)

    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    for node_name, tool_obj in tool_by_node.items():
        g.add_node(node_name, make_tool_node(tool_obj))
    g.add_node("answer", answer_node)

    # Wire up edges
    g.set_entry_point("init")
    g.add_edge("init", "human")
    g.add_edge("human", "assistant")

    # assistant → tool nodes / answer, conditional on the requested tool calls;
    # each tool node → next pending tool, or back to the assistant.
    targets = {name: name for name in (*tool_by_node, "assistant", "answer")}
    g.add_conditional_edges("assistant", next_step, targets)
    for node_name in tool_by_node:
        g.add_conditional_edges(node_name, next_step, targets)

    g.add_edge("answer", END)
    return g.compile()