File size: 5,664 Bytes
dfe4f5a
6b8a544
67d8dd6
8d47cbc
 
6b8a544
 
 
 
 
 
af3f4d1
6b8a544
 
becfda3
8d47cbc
6b8a544
 
 
8d47cbc
 
 
 
 
 
 
 
6b8a544
 
 
 
 
8d47cbc
6b8a544
 
 
1003bb3
88f109e
1003bb3
4d5b045
6b8a544
4d5b045
8d47cbc
 
 
 
 
6b8a544
8d47cbc
14f3941
7ff1ac7
 
6b8a544
 
 
becfda3
 
 
 
6b8a544
becfda3
 
6b8a544
becfda3
6b8a544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# agent.py

import os
import pandas as pd
import faiss

from sentence_transformers import SentenceTransformer
from serpapi import GoogleSearch

# 1️⃣ Switch Graph β†’ StateGraph
from langgraph.graph import StateGraph
from langchain_core.language_models.llms import LLM
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader

# ────────────────
# 2️⃣ Load & index your static FAISS docs
# ────────────────
# Load the static document corpus once at import time.
df = pd.read_csv("documents.csv")
DOCS = df["content"].tolist()

# Embed every document with a sentence-transformer and build an exact
# (non-approximate) L2 nearest-neighbour index over the vectors.
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
_embedding_dim = EMBS.shape[1]
INDEX = faiss.IndexFlatL2(_embedding_dim)
INDEX.add(EMBS)

# ────────────────
# 3️⃣ Read your system prompt
# ────────────────
# Read the agent's system prompt from disk; trailing whitespace is stripped.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    SYSTEM_PROMPT = prompt_file.read().strip()

# ────────────────
# 4️⃣ Define your tools (unchanged semantics)
# ────────────────
@tool
def calculator(expr: str) -> str:
    """Evaluate an arithmetic Python expression and return the result as a string.

    Args:
        expr: A Python expression such as ``"2 + 3 * 4"``.

    Returns:
        ``str(result)`` on success, or the literal string ``"Error"`` on any
        failure (kept unchanged because callers may match on it).
    """
    try:
        # SECURITY: eval() on agent-supplied text is dangerous. Emptying
        # __builtins__ restricts the expression to literals and operators so
        # it cannot reach open(), __import__(), etc.
        return str(eval(expr, {"__builtins__": {}}, {}))
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of being swallowed.
        return "Error"

@tool
def retrieve_docs(query: str, k: int = 3) -> str:
    """Return the k documents nearest to `query` from the FAISS index,
    joined with a '---' separator."""
    query_vec = EMBEDDER.encode([query]).astype("float32")
    _dists, neighbor_ids = INDEX.search(query_vec, k)
    hits = [DOCS[idx] for idx in neighbor_ids[0]]
    return "\n\n---\n\n".join(hits)

# SerpAPI credential is read once at import time from the environment.
SERPAPI_KEY = os.getenv("SERPAPI_KEY")


@tool
def web_search(query: str, num_results: int = 5) -> str:
    """Run a Google search via SerpAPI and return one bulleted snippet
    per organic result."""
    params = {
        "engine": "google",
        "q": query,
        "num": num_results,
        "api_key": SERPAPI_KEY,
    }
    organic = GoogleSearch(params).get_dict().get("organic_results", [])
    snippets = [f"- {r.get('snippet','')}" for r in organic]
    return "\n".join(snippets)

@tool
def wiki_search(query: str) -> str:
    """Fetch up to two Wikipedia pages matching `query` and join their
    full text with a '---' separator."""
    loaded = WikipediaLoader(query=query, load_max_docs=2).load()
    contents = [doc.page_content for doc in loaded]
    return "\n\n---\n\n".join(contents)

@tool
def arxiv_search(query: str) -> str:
    """Fetch up to three arXiv papers matching `query`; each paper's text is
    truncated to its first 1000 characters and joined with '---'."""
    loaded = ArxivLoader(query=query, load_max_docs=3).load()
    excerpts = [doc.page_content[:1000] for doc in loaded]
    return "\n\n---\n\n".join(excerpts)


# ────────────────
# 5️⃣ Define your State schema
# ────────────────
from typing import TypedDict, List
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
    """Graph state schema: the running chat transcript passed between nodes."""
    # Full ordered message history (system prompt, human turns, AI replies,
    # and tool outputs appended as the graph runs).
    messages: List[BaseMessage]


# ────────────────
# 6️⃣ Build the StateGraph
# ────────────────
def build_graph(provider: str = "huggingface") -> StateGraph:
    """Build and compile the agent's LangGraph state machine.

    Args:
        provider: Provider label forwarded to the LLM constructor.

    Returns:
        The value of ``g.compile()`` (a compiled graph, despite the
        ``StateGraph`` return annotation — TODO confirm/adjust annotation).

    Raises:
        ValueError: if the ``HF_TOKEN`` environment variable is not set.
    """
    # Instantiate LLM
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN missing in env")
    # NOTE(review): `LLM` imported from langchain_core.language_models.llms is
    # the *abstract base class*; it takes no provider/token/model kwargs and
    # cannot be instantiated directly. A concrete class (e.g. a HuggingFace
    # endpoint wrapper) is likely intended here — TODO confirm.
    llm = LLM(provider=provider, token=hf_token, model="meta-llama/Llama-2-7b-chat-hf")

    # 6.1) Node: init → seed system prompt
    # Discards any incoming state and starts the transcript with only the
    # system prompt message.
    def init_node(_: AgentState) -> AgentState:
        return {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT)
            ]
        }
    
    # 6.2) Node: human → append user question
    # NOTE(review): LangGraph invokes a node with the state argument only; the
    # extra positional `question` parameter will never be supplied at runtime,
    # so this node would raise. The question presumably has to travel inside
    # the state instead — TODO confirm against the langgraph node contract.
    def human_node(state: AgentState, question: str) -> AgentState:
        state["messages"].append(HumanMessage(content=question))
        return state

    # 6.3) Node: assistant → call LLM on current messages
    def assistant_node(state: AgentState) -> dict:
        ai_msg = llm.invoke(state["messages"])
        # Builds a new list rather than mutating state["messages"] in place.
        return {"messages": state["messages"] + [ai_msg]}

    # 6.4) Optional: tool nodes (they’ll read last HumanMessage)
    # Factory: wraps a tool function as a graph node that feeds it the most
    # recent message's text and appends the tool output to the transcript.
    def make_tool_node(fn):
        def tool_node(state: AgentState) -> dict:
            # fetch the latest human query
            # NOTE(review): after assistant_node runs, messages[-1] is the AI
            # reply, not the human query — verify which message the tools are
            # meant to receive.
            last_query = state["messages"][-1].content
            result = fn(last_query)
            # append the tool’s output as if from system/Human
            state["messages"].append(HumanMessage(content=result))
            return {"messages": state["messages"]}
        return tool_node

    # Instantiate nodes for each tool
    calc_node     = make_tool_node(calculator)
    retrieve_node = make_tool_node(retrieve_docs)
    web_node      = make_tool_node(web_search)
    wiki_node     = make_tool_node(wiki_search)
    arxiv_node    = make_tool_node(arxiv_search)

    # 6.5) Build the graph
    g = StateGraph(AgentState)

    # Register nodes
    g.add_node("init",    init_node)
    g.add_node("human",   human_node)
    g.add_node("assistant", assistant_node)
    g.add_node("calc",    calc_node)
    g.add_node("retrieve", retrieve_node)
    g.add_node("web",      web_node)
    g.add_node("wiki",     wiki_node)
    g.add_node("arxiv",    arxiv_node)

    # Wire up edges
    from langgraph.graph import END
    g.set_entry_point("init")
    # init → human (placeholder: we’ll inject the actual question at runtime)
    g.add_edge("init", "human")
    # human → assistant
    g.add_edge("human", "assistant")
    # assistant → tool nodes (conditional on tool calls)
    # NOTE(review): add_edge creates *unconditional* edges. As wired, the
    # assistant fans out to every tool on every step AND to END, and every
    # tool loops straight back into the assistant — this almost certainly
    # needs add_conditional_edges with a router function instead. TODO confirm.
    g.add_edge("assistant", "calc")
    g.add_edge("assistant", "retrieve")
    g.add_edge("assistant", "web")
    g.add_edge("assistant", "wiki")
    g.add_edge("assistant", "arxiv")
    # each tool returns back into assistant for follow‐up
    g.add_edge("calc",     "assistant")
    g.add_edge("retrieve", "assistant")
    g.add_edge("web",      "assistant")
    g.add_edge("wiki",     "assistant")
    g.add_edge("arxiv",    "assistant")
    # and finally assistant → END when done
    g.add_edge("assistant", END)

    return g.compile()