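"""LangGraph + Chainlit demo agent.

Builds a tool-calling agent (Yahoo Finance news, Tavily search, arXiv lookup,
and a custom multiply tool) and streams only the output of the final
summarization node back to the Chainlit UI.
"""
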
from typing import Literal

import chainlit as cl

from langchain_openai import ChatOpenAI
from langchain.schema import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable.config import RunnableConfig
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode

# tools
from langchain_core.tools import tool
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool

@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    return first_int * second_int

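# Instantiate the prebuilt community tools alongside the custom multiply tool.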
yahoo_finance_news = YahooFinanceNewsTool()
tavily_search = TavilySearchResults(max_results=5)
arxiv_query = ArxivQueryRun()
tools = [
    yahoo_finance_news,
    tavily_search,
    arxiv_query,
    multiply,
]

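# Two model instances: `model` drives the agent (and gets the tools bound below),
# while `final_model` only rewrites the agent's answer for the user.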
model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
final_model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
model = model.bind_tools(tools)
# NOTE: this is where we add a tag that we can use later to filter the model
# stream events to only the model called in the final node. This is not
# necessary if you call a single LLM, but it might be important if you call
# multiple models within the node and want to filter events from only one of them.
final_model = final_model.with_config(tags=["final_node"])
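# ToolNode executes whichever tools the agent requested and appends their
# results to the message state.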
tool_node = ToolNode(tools=tools)
def should_continue(state: MessagesState) -> Literal["tools", "final"]:
    messages = state["messages"]
    last_message = messages[-1]
    # If the LLM makes a tool call, then we route to the "tools" node
    if last_message.tool_calls:
        return "tools"
    # Otherwise, we stop (reply to the user)
    return "final"

def call_model(state: MessagesState):
    messages = state["messages"]
    response = model.invoke(messages)
    # We return a list, because this will get merged with the existing list
    return {"messages": [response]}

def call_final_model(state: MessagesState):
    messages = state["messages"]
    last_ai_message = messages[-1]
    response = final_model.invoke(
        [
            SystemMessage("Provide a summary in point form notes of the following:"),
            HumanMessage(last_ai_message.content),
        ]
    )
    # overwrite the last AI message from the agent
    response.id = last_ai_message.id
    return {"messages": [response]}

builder = StateGraph(MessagesState)
builder.add_node("agent", call_model)
builder.add_node("tools", tool_node)
# add a separate final node
builder.add_node("final", call_final_model)
builder.add_edge(START, "agent")
builder.add_conditional_edges(
    "agent",
    should_continue,
)
builder.add_edge("tools", "agent")
builder.add_edge("final", END)
graph = builder.compile()
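# A minimal sketch of invoking the compiled graph directly, outside Chainlit
# (the question below is just an illustrative placeholder):
#   result = graph.invoke({"messages": [HumanMessage(content="What is 6 times 7?")]})
#   print(result["messages"][-1].content)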
@cl.on_chat_start
async def on_chat_start():
    model = ChatOpenAI(streaming=True)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You're a very knowledgeable agent with access to several tools to get recent data and multiply numbers.",
            ),
            ("human", "{question}"),
        ]
    )
    runnable = prompt | model | StrOutputParser()
    # Stored in the user session; on_message below streams the compiled graph directly.
    cl.user_session.set("runnable", runnable)

@cl.on_message
async def on_message(msg: cl.Message):
    config = {}
    cb = cl.LangchainCallbackHandler()
    final_answer = cl.Message(content="")

    # Stream (message chunk, metadata) pairs from the graph and forward only
    # tokens produced by the "final" node to the UI.
    for chunk, metadata in graph.stream(
        {"messages": [HumanMessage(content=msg.content)]},
        stream_mode="messages",
        config=RunnableConfig(callbacks=[cb], **config),
    ):
        if (
            chunk.content
            and not isinstance(chunk, HumanMessage)
            and metadata["langgraph_node"] == "final"
        ):
            await final_answer.stream_token(chunk.content)

    await final_answer.send()
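# To run locally (assuming this file is saved as app.py and OPENAI_API_KEY /
# TAVILY_API_KEY are set in the environment):
#   chainlit run app.py -w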