File size: 2,332 Bytes
c58342b
f224484
c58342b
 
 
 
 
 
 
 
 
 
 
 
 
6c5b30f
c58342b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f224484
c58342b
 
 
 
 
 
 
 
 
 
 
f224484
c58342b
f224484
c58342b
 
f224484
c58342b
f224484
c58342b
 
 
 
 
 
6c5b30f
c58342b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_core.messages import SystemMessage, HumanMessage
from prompts import SYS_PROMPT
from tools import tools
from retriever import vector_store
from langchain_openai import ChatOpenAI

load_dotenv()  # load variables from a .env file into the environment (e.g. OPENAI_API_KEY read below)



# System message built from prompts.SYS_PROMPT; shared by the graph's nodes.
sys_msg = SystemMessage(content=SYS_PROMPT)


# Build graph function
def build_graph(*, model: str = "gpt-4o", temperature: float = 0.1, k: int = 3):
    """Build and compile the agent graph.

    Flow: ``START -> retriever -> assistant``. After that, ``tools_condition``
    either routes to ``tools``, which returns to ``assistant``, or ends the
    run when the LLM stops requesting tool calls.

    Args:
        model: OpenAI chat model name passed to ``ChatOpenAI``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.
        k: Number of similar examples fetched from ``vector_store``.

    Returns:
        The compiled graph over ``MessagesState``.
    """
    llm = ChatOpenAI(
        temperature=temperature,
        model=model,
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )
    # Bind tools to LLM so its tool calls can be executed by the ToolNode.
    llm_with_tools = llm.bind_tools(tools)

    # Node
    def assistant(state: MessagesState):
        """Assistant node: call the tool-bound LLM on the conversation so far."""
        # Prepend the system prompt per call instead of storing it in state.
        # The add_messages reducer appends messages with new ids, so a
        # SystemMessage returned from a node would end up *after* the user's
        # question rather than first.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: append up to ``k`` similar Q/A examples for reference."""
        messages = state["messages"]
        if not messages:
            # Nothing to search with; leave the state unchanged.
            return {"messages": []}
        similar_question = vector_store.similarity_search(messages[0].content, k=k)
        similar_question_content = "\n".join(
            f"{idx + 1}. {doc.page_content}" for idx, doc in enumerate(similar_question)
        )
        example_msg = HumanMessage(
            content=f"Here I provide some similar questions and answer for reference in case you can't find answer from tool result: \n\n{similar_question_content}",
        )
        # Return only the new message; the reducer appends it to the history.
        return {"messages": [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the LLM requested tool calls, otherwise finish.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()

# test
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    # Build the graph
    graph = build_graph()
    # Run the graph; invoke() returns the final state dict, not a message list.
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for m in result["messages"]:
        m.pretty_print()
    # The last message is the assistant's final reply. tools_condition ends
    # the run once the LLM stops requesting tool calls.
    answer = result["messages"][-1].content
    print(f"\nFinal answer: {answer}")