# AIxGAIA / agent.py
# Author: AC-Angelo93 — "Update agent.py" (commit 6b8a544, verified)
# (Header reconstructed from HuggingFace file-viewer residue; the original
# scraped lines were not valid Python.)
# agent.py
import os
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from serpapi import GoogleSearch
# 1️⃣ Switch Graph β†’ StateGraph
from langgraph.graph import StateGraph
from langchain_core.language_models.llms import LLM
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
# Module-level setup, runs once at import time:
#   DOCS     — raw document texts from the "content" column of documents.csv
#   EMBEDDER — sentence-transformers model producing dense embeddings
#   INDEX    — FAISS exact L2 index over the document embeddings
# NOTE(review): requires "documents.csv" in the CWD and network access to
# download the model on first run — confirm deployment environment.
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# FAISS requires float32; .encode returns float32 by default but we cast
# defensively.
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
INDEX.add(EMBS)
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
# Loaded once at import time; seeds every conversation (see init_node below).
# NOTE(review): requires "system_prompt.txt" next to the working directory.
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read().strip()
# ────────────────
# 4️⃣ Define your tools (unchanged semantics)
# ────────────────
@tool
def calculator(expr: str) -> str:
    """Evaluate a Python expression and return its result as a string.

    Args:
        expr: A Python expression, e.g. "2 + 3 * 4".

    Returns:
        The stringified result, or "Error" if evaluation fails for any reason.
    """
    try:
        # SECURITY NOTE: eval() on model/user-supplied text can execute
        # arbitrary code. If inputs are untrusted, replace with a restricted
        # arithmetic evaluator (e.g. ast.literal_eval or a parser).
        return str(eval(expr))
    except Exception:
        # Was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt; Exception keeps the "never crash" contract
        # without masking interpreter-level signals.
        return "Error"
@tool
def retrieve_docs(query: str, k: int = 3) -> str:
    """Return the k documents nearest to `query` from the FAISS index.

    Args:
        query: Free-text search query.
        k: Number of neighbours to retrieve (default 3).

    Returns:
        The matching document texts joined by "---" separators.
    """
    q_emb = EMBEDDER.encode([query]).astype("float32")
    D, I = INDEX.search(q_emb, k)
    # FAISS pads the result with index -1 when the index holds fewer than k
    # vectors; without the guard, DOCS[-1] would silently return the last
    # document for every missing neighbour.
    return "\n\n---\n\n".join(DOCS[i] for i in I[0] if i >= 0)
# SerpAPI credential, read once at import time (may be None if unset).
SERPAPI_KEY = os.getenv("SERPAPI_KEY")


@tool
def web_search(query: str, num_results: int = 5) -> str:
    """Run a Google search through SerpAPI and return result snippets.

    Args:
        query: Search query string.
        num_results: Maximum number of organic results to request.

    Returns:
        One "- snippet" bullet line per organic result, newline-joined.
    """
    search = GoogleSearch({
        "engine": "google",
        "q": query,
        "num": num_results,
        "api_key": SERPAPI_KEY,
    })
    organic = search.get_dict().get("organic_results", [])
    bullets = [f"- {hit.get('snippet','')}" for hit in organic]
    return "\n".join(bullets)
@tool
def wiki_search(query: str) -> str:
    """Fetch up to two Wikipedia pages for `query`.

    Returns:
        The full page contents joined by "---" separators.
    """
    loader = WikipediaLoader(query=query, load_max_docs=2)
    contents = [doc.page_content for doc in loader.load()]
    return "\n\n---\n\n".join(contents)
@tool
def arxiv_search(query: str) -> str:
    """Fetch up to three arXiv papers matching `query`.

    Returns:
        The first 1000 characters of each paper, joined by "---" separators.
    """
    loader = ArxivLoader(query=query, load_max_docs=3)
    excerpts = [doc.page_content[:1000] for doc in loader.load()]
    return "\n\n---\n\n".join(excerpts)
# ────────────────
# 5️⃣ Define your State schema
# ────────────────
from typing import TypedDict, List
from langchain_core.messages import BaseMessage
class AgentState(TypedDict):
    """Graph state: the running chat history carried between nodes."""
    # The full message list ("chat history"): system prompt, user questions,
    # LLM replies, and tool outputs (appended as HumanMessages by the tool
    # nodes in build_graph).
    messages: List[BaseMessage]
# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
def build_graph(provider: str = "huggingface") -> StateGraph:
    """Build and compile the agent's LangGraph state machine.

    Args:
        provider: LLM provider name forwarded to the LLM constructor.

    Returns:
        The compiled graph (the return annotation says StateGraph, but
        `g.compile()` actually returns a compiled/runnable graph object).

    Raises:
        ValueError: if the HF_TOKEN environment variable is missing.
    """
    # Instantiate LLM
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN missing in env")
    # NOTE(review): `LLM` from langchain_core.language_models.llms is an
    # abstract base class and takes no provider/token/model kwargs — this
    # line almost certainly raises at runtime. A concrete wrapper (e.g.
    # HuggingFaceEndpoint from langchain_huggingface) is probably intended;
    # confirm against the installed library versions.
    llm = LLM(provider=provider, token=hf_token, model="meta-llama/Llama-2-7b-chat-hf")

    # 6.1) Node: init -> seed the conversation with the system prompt.
    def init_node(_: AgentState) -> AgentState:
        return {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT)
            ]
        }

    # 6.2) Node: human -> append the user question to the history.
    # NOTE(review): LangGraph invokes nodes with the state only; the extra
    # `question` parameter will not be supplied by the graph runner, so this
    # node will fail when executed by the compiled graph — verify how the
    # question is meant to be injected (likely it should live in the state).
    def human_node(state: AgentState, question: str) -> AgentState:
        state["messages"].append(HumanMessage(content=question))
        return state

    # 6.3) Node: assistant -> run the LLM over the accumulated messages and
    # append its reply (returns a new list; does not mutate in place).
    def assistant_node(state: AgentState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        return {"messages": state["messages"] + [ai_msg]}

    # 6.4) Tool-node factory: wraps one tool function into a graph node that
    # feeds it the content of the most recent message.
    def make_tool_node(fn):
        def tool_node(state: AgentState) -> dict:
            # fetch the latest human query
            # NOTE(review): after assistant_node runs, the last message is
            # the AI reply, not the human question — confirm which text the
            # tools are meant to receive.
            last_query = state["messages"][-1].content
            # NOTE(review): calling a @tool-decorated object positionally is
            # deprecated/removed in newer langchain-core; `fn.invoke(...)`
            # may be required — confirm against the installed version.
            result = fn(last_query)
            # Tool output is appended as a HumanMessage so the LLM sees it
            # on the next assistant turn.
            state["messages"].append(HumanMessage(content=result))
            return {"messages": state["messages"]}
        return tool_node

    # Instantiate nodes for each tool
    calc_node = make_tool_node(calculator)
    retrieve_node = make_tool_node(retrieve_docs)
    web_node = make_tool_node(web_search)
    wiki_node = make_tool_node(wiki_search)
    arxiv_node = make_tool_node(arxiv_search)

    # 6.5) Build the graph
    g = StateGraph(AgentState)
    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("calc", calc_node)
    g.add_node("retrieve", retrieve_node)
    g.add_node("web", web_node)
    g.add_node("wiki", wiki_node)
    g.add_node("arxiv", arxiv_node)

    # Wire up edges
    from langgraph.graph import END
    g.set_entry_point("init")
    # init -> human (placeholder: the actual question is injected at runtime)
    g.add_edge("init", "human")
    # human -> assistant
    g.add_edge("human", "assistant")
    # assistant -> tool nodes
    # NOTE(review): these edges are UNconditional — after every assistant
    # turn, all five tools fire (and each loops back into assistant), while
    # assistant also has a direct edge to END below. The comment says
    # "conditional on tool calls", which would require
    # `g.add_conditional_edges(...)` keyed on the LLM's tool choice; as
    # wired, the graph likely fans out to everything or never terminates —
    # confirm intended routing.
    g.add_edge("assistant", "calc")
    g.add_edge("assistant", "retrieve")
    g.add_edge("assistant", "web")
    g.add_edge("assistant", "wiki")
    g.add_edge("assistant", "arxiv")
    # each tool returns back into assistant for follow-up
    g.add_edge("calc", "assistant")
    g.add_edge("retrieve", "assistant")
    g.add_edge("web", "assistant")
    g.add_edge("wiki", "assistant")
    g.add_edge("arxiv", "assistant")
    # and finally assistant -> END when done
    g.add_edge("assistant", END)
    return g.compile()