"""LangGraph Agent.

Builds a graph that first looks up similar reference questions in a vector
store, then lets a tool-calling GPT-4o model answer, looping through the tool
node whenever the model requests a tool call.
"""

import os

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from prompts import SYS_PROMPT
from retriever import vector_store
from tools import tools

load_dotenv()

# System message prepended to every LLM call made by the assistant node.
sys_msg = SystemMessage(content=SYS_PROMPT)


def build_graph():
    """Build and compile the agent graph.

    Flow: START -> retriever -> assistant, then `tools_condition` routes either
    to the "tools" node (which feeds back into the assistant) or ends the run.

    Returns:
        The compiled graph; invoke it with ``{"messages": [HumanMessage(...)]}``.
    """
    llm = ChatOpenAI(
        temperature=0.1,
        model="gpt-4o",
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )
    # Bind tools to LLM so it can emit tool calls handled by the ToolNode.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the tool-bound LLM on the conversation, system prompt first."""
        # The system prompt is supplied at call time rather than written into
        # the state: MessagesState's add_messages reducer appends any message
        # whose id is not already present, so a SystemMessage returned from a
        # node would land *after* the user's question instead of at the front.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Append up to 3 similar reference examples for the incoming question."""
        # Assumes the first message in the state is the user's question.
        similar_docs = vector_store.similarity_search(state["messages"][0].content, k=3)
        if not similar_docs:
            # Nothing similar found; don't send the model an empty reference list.
            return {"messages": []}
        similar_question_content = "\n".join(
            f"{idx}. {doc.page_content}" for idx, doc in enumerate(similar_docs, start=1)
        )
        example_msg = HumanMessage(
            content=f"Here I provide some similar questions and answer for reference in case you can't find answer from tool result: \n\n{similar_question_content}",
        )
        # Return only the new message; the add_messages reducer appends it to
        # the existing history (re-returning old messages would be a no-op).
        return {"messages": [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the last AI message contains tool calls, else end.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()


# test
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    # Build the graph
    graph = build_graph()
    # Run the graph
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for m in result["messages"]:
        m.pretty_print()