# agent.py
import os
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from serpapi import GoogleSearch
# 1️⃣ Switch Graph → StateGraph
from langgraph.graph import StateGraph, END
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
INDEX.add(EMBS)
# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read().strip()
# ────────────────
# 4️⃣ Define your tools (unchanged semantics)
# ────────────────
@tool
def calculator(expr: str) -> str:
    """Evaluate a Python arithmetic expression and return the result."""
    try:
        return str(eval(expr))  # NOTE: eval is unsafe on untrusted input
    except Exception as e:
        return f"Error: {e}"
@tool
def retrieve_docs(query: str, k: int = 3) -> str:
    """Return the k most similar documents from the FAISS index."""
    q_emb = EMBEDDER.encode([query]).astype("float32")
    D, I = INDEX.search(q_emb, k)
    return "\n\n---\n\n".join(DOCS[i] for i in I[0])
SERPAPI_KEY = os.getenv("SERPAPI_KEY")
@tool
def web_search(query: str, num_results: int = 5) -> str:
    """Search Google via SerpAPI and return result snippets."""
    params = {"engine": "google", "q": query, "num": num_results, "api_key": SERPAPI_KEY}
    res = GoogleSearch(params).get_dict().get("organic_results", [])
    return "\n".join(f"- {r.get('snippet', '')}" for r in res)
@tool
def wiki_search(query: str) -> str:
    """Return the content of up to two matching Wikipedia pages."""
    pages = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(d.page_content for d in pages)
@tool
def arxiv_search(query: str) -> str:
    """Return the first 1,000 characters of up to three arXiv papers."""
    papers = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(d.page_content[:1000] for d in papers)
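# Quick sanity check (sketch): @tool-decorated functions are called via
# .invoke(), not directly. The query strings below are purely illustrative:
#   print(retrieve_docs.invoke("vector databases"))
#   print(calculator.invoke("2 ** 10"))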
# ────────────────
# 5️⃣ Define your State schema
# ────────────────
from typing import TypedDict, List
from langchain_core.messages import BaseMessage
class AgentState(TypedDict):
    # We'll carry a list of messages as our "chat history",
    # plus the user's question, injected into state at runtime
    messages: List[BaseMessage]
    question: str
# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
def build_graph(provider: str = "huggingface") -> StateGraph:
    # Instantiate the chat LLM (only the Hugging Face provider is wired up here)
    if provider != "huggingface":
        raise ValueError(f"Unsupported provider: {provider}")
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN missing in env")
    llm = ChatHuggingFace(
        llm=HuggingFaceEndpoint(
            repo_id="meta-llama/Llama-2-7b-chat-hf",
            huggingfacehub_api_token=hf_token,
            task="text-generation",
        )
    )
    # 6.1) Node: init → seed system prompt
    def init_node(_: AgentState) -> AgentState:
        return {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT)
            ]
        }
    # 6.2) Node: human → append the user question carried in state
    # (LangGraph nodes receive only the state, so the question rides in
    # state["question"] rather than as a second argument)
    def human_node(state: AgentState) -> AgentState:
        state["messages"].append(HumanMessage(content=state["question"]))
        return state
    # 6.3) Node: assistant → call LLM on current messages
    def assistant_node(state: AgentState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        return {"messages": state["messages"] + [ai_msg]}
    # 6.4) Optional: tool nodes (they'll read the latest message)
    def make_tool_node(fn):
        def tool_node(state: AgentState) -> dict:
            # fetch the latest message content as the tool's input
            last_query = state["messages"][-1].content
            result = fn.invoke(last_query)  # tools are called via .invoke()
            # append the tool's output back into the conversation
            state["messages"].append(HumanMessage(content=result))
            return {"messages": state["messages"]}
        return tool_node
    # Instantiate nodes for each tool
    calc_node = make_tool_node(calculator)
    retrieve_node = make_tool_node(retrieve_docs)
    web_node = make_tool_node(web_search)
    wiki_node = make_tool_node(wiki_search)
    arxiv_node = make_tool_node(arxiv_search)
    # 6.5) Build the graph
    g = StateGraph(AgentState)
    # Register nodes
    g.add_node("init", init_node)
    g.add_node("human", human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("calc", calc_node)
    g.add_node("retrieve", retrieve_node)
    g.add_node("web", web_node)
    g.add_node("wiki", wiki_node)
    g.add_node("arxiv", arxiv_node)
    # Wire up edges
    g.set_entry_point("init")
    # init → human (the question itself is injected via state at runtime)
    g.add_edge("init", "human")
    # human → assistant
    g.add_edge("human", "assistant")
    # assistant → tool nodes or END, routed conditionally. Unconditional
    # edges here would fan out to every tool in parallel on each step and
    # loop forever, so we route on the assistant's reply instead. The
    # "[tool]" tag convention is an assumption, not a fixed protocol: the
    # system prompt is expected to tell the model to name a tool in brackets.
    def route_tools(state: AgentState) -> str:
        text = state["messages"][-1].content.lower()
        for name in ("calc", "retrieve", "web", "wiki", "arxiv"):
            if f"[{name}]" in text:
                return name
        return END
    g.add_conditional_edges("assistant", route_tools)
    # each tool returns back into assistant for follow-up
    g.add_edge("calc", "assistant")
    g.add_edge("retrieve", "assistant")
    g.add_edge("web", "assistant")
    g.add_edge("wiki", "assistant")
    g.add_edge("arxiv", "assistant")
    return g.compile()
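# ────────────────
# 7️⃣ Example usage (sketch)
# ────────────────
# A minimal smoke test, assuming documents.csv and system_prompt.txt sit
# alongside this file and HF_TOKEN is set in the environment. The question
# string is purely illustrative; it is injected into the initial state,
# which is where the "human" node reads it from.
if __name__ == "__main__":
    graph = build_graph()
    final_state = graph.invoke({
        "messages": [],
        "question": "What is retrieval-augmented generation?",
    })
    print(final_state["messages"][-1].content)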