# agent.py
import os
from typing import TypedDict, List

import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from serpapi import GoogleSearch

# 1️⃣ Switch Graph → StateGraph
from langgraph.graph import StateGraph, END
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
INDEX.add(EMBS)
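
# Optional sanity check — a minimal sketch, not part of the original file:
# embed an arbitrary probe string and confirm the index returns plausible
# neighbors before wiring the retriever into the agent.
def _sanity_check_index(probe: str = "example query", k: int = 2) -> None:
    q_emb = EMBEDDER.encode([probe]).astype("float32")
    _, ids = INDEX.search(q_emb, k)
    for i in ids[0]:
        print(DOCS[i][:80])  # print the first 80 chars of each hit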
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read().strip()
# ────────────────
# 4️⃣ Define your tools (unchanged semantics)
# ────────────────
def calculator(expr: str) -> str:
    # NOTE: eval runs with an empty builtins namespace, but this is still
    # only appropriate for trusted input
    try:
        return str(eval(expr, {"__builtins__": {}}, {}))
    except Exception as e:
        return f"Error: {e}"
def retrieve_docs(query: str, k: int = 3) -> str:
    # embed the query and return the k nearest documents from the FAISS index
    q_emb = EMBEDDER.encode([query]).astype("float32")
    _, ids = INDEX.search(q_emb, k)
    return "\n\n---\n\n".join(DOCS[i] for i in ids[0])
SERPAPI_KEY = os.getenv("SERPAPI_KEY")

def web_search(query: str, num_results: int = 5) -> str:
    # requires a SerpAPI key in the SERPAPI_KEY env var
    params = {"engine": "google", "q": query, "num": num_results, "api_key": SERPAPI_KEY}
    res = GoogleSearch(params).get_dict().get("organic_results", [])
    return "\n".join(f"- {r.get('snippet', '')}" for r in res)
def wiki_search(query: str) -> str:
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(d.page_content for d in pages)

def arxiv_search(query: str) -> str:
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(d.page_content[:1000] for d in papers)
# ────────────────
# 5️⃣ Define your State schema
# ────────────────
class AgentState(TypedDict):
    # We'll carry the user's question plus a list of messages as our "chat history"
    question: str
    messages: List[BaseMessage]
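
# For reference: LangGraph also ships an `add_messages` reducer that merges
# message lists automatically, so nodes could return just their new messages.
# A sketch of the equivalent schema (shown only for comparison; the plain
# list above is what the nodes below actually use):
from typing import Annotated
from langgraph.graph.message import add_messages

class ReducedAgentState(TypedDict):
    question: str
    messages: Annotated[List[BaseMessage], add_messages]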
# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
def build_graph(provider: str = "huggingface"):
    # Instantiate the LLM (only the Hugging Face provider is wired up here;
    # the parameter is kept for future backends)
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN missing in env")
    endpoint = HuggingFaceEndpoint(
        repo_id="meta-llama/Llama-2-7b-chat-hf",
        huggingfacehub_api_token=hf_token,
    )
    llm = ChatHuggingFace(llm=endpoint)
    # 6.1) Node: init → seed system prompt
    def init_node(_: AgentState) -> dict:
        return {"messages": [SystemMessage(content=SYSTEM_PROMPT)]}
    # 6.2) Node: human → append the user question carried in state
    def human_node(state: AgentState) -> dict:
        return {"messages": state["messages"] + [HumanMessage(content=state["question"])]}
    # 6.3) Node: assistant → call LLM on current messages
    def assistant_node(state: AgentState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        return {"messages": state["messages"] + [ai_msg]}
    # 6.4) Tool nodes: each runs its function over the user's question
    def make_tool_node(fn):
        def tool_node(state: AgentState) -> dict:
            result = fn(state["question"])
            # append the tool's output as a HumanMessage so the assistant
            # sees it on the next turn
            return {"messages": state["messages"] + [HumanMessage(content=result)]}
        return tool_node
    # Instantiate nodes for each tool
    calc_node = make_tool_node(calculator)
    retrieve_node = make_tool_node(retrieve_docs)
    web_node = make_tool_node(web_search)
    wiki_node = make_tool_node(wiki_search)
    arxiv_node = make_tool_node(arxiv_search)
    # 6.5) Build the graph
    g = StateGraph(AgentState)
    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("calc", calc_node)
    g.add_node("retrieve", retrieve_node)
    g.add_node("web", web_node)
    g.add_node("wiki", wiki_node)
    g.add_node("arxiv", arxiv_node)
    # Wire up edges
    g.set_entry_point("init")
    # init → human (the question itself is read from state at runtime)
    g.add_edge("init", "human")
    # human → assistant
    g.add_edge("human", "assistant")
    # assistant → tool nodes or END, decided by a router (plain add_edge to
    # every tool plus END would fan out to all of them at once and loop).
    # The router looks for a "TOOL:<name>" directive in the assistant's last
    # reply; this directive convention is an assumption and must be enforced
    # by SYSTEM_PROMPT for the routing to fire.
    def route_from_assistant(state: AgentState) -> str:
        last = state["messages"][-1].content
        for name in ("calc", "retrieve", "web", "wiki", "arxiv"):
            if f"TOOL:{name}" in last:
                return name
        return END

    g.add_conditional_edges(
        "assistant",
        route_from_assistant,
        {"calc": "calc", "retrieve": "retrieve", "web": "web",
         "wiki": "wiki", "arxiv": "arxiv", END: END},
    )
    # each tool returns back into assistant for follow-up
    g.add_edge("calc", "assistant")
    g.add_edge("retrieve", "assistant")
    g.add_edge("web", "assistant")
    g.add_edge("wiki", "assistant")
    g.add_edge("arxiv", "assistant")
    return g.compile()
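
# ────────────────
# 7️⃣ Example usage — a sketch, assuming HF_TOKEN is set and documents.csv /
# system_prompt.txt exist; the question below is purely illustrative
# ────────────────
if __name__ == "__main__":
    graph = build_graph()
    final_state = graph.invoke({"question": "What is 2 + 2?", "messages": []})
    for msg in final_state["messages"]:
        print(f"{msg.type}: {msg.content}")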