from typing import Literal

from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from langchain.schema import StrOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable.config import RunnableConfig
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import MessagesState

# tools
from langchain_core.tools import tool
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool

import chainlit as cl


@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    return first_int * second_int


yahoo_finance_news = YahooFinanceNewsTool()
tavily_search = TavilySearchResults(max_results=5)
arxiv_query = ArxivQueryRun()

tools = [
    yahoo_finance_news,
    tavily_search,
    arxiv_query,
    multiply,
]

model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)
final_model = ChatOpenAI(model_name="gpt-4o-mini", temperature=0)

model = model.bind_tools(tools)
# NOTE: this is where we add a tag that we can use later to filter the model stream
# events to only the model called in the final node. This is not necessary if you call
# a single LLM, but it matters if you call multiple models within the node and want to
# filter events from only one of them.
final_model = final_model.with_config(tags=["final_node"])

tool_node = ToolNode(tools=tools)


def should_continue(state: MessagesState) -> Literal["tools", "final"]:
    messages = state["messages"]
    last_message = messages[-1]
    # If the LLM makes a tool call, then we route to the "tools" node
    if last_message.tool_calls:
        return "tools"
    # Otherwise, we stop (reply to the user)
    return "final"


def call_model(state: MessagesState):
    messages = state["messages"]
    response = model.invoke(messages)
    # We return a list, because this will get merged with the existing list
    return {"messages": [response]}


def call_final_model(state: MessagesState):
    messages = state["messages"]
    last_ai_message = messages[-1]
    response = final_model.invoke(
        [
            SystemMessage("Provide a summary in point form notes of the following:"),
            HumanMessage(last_ai_message.content),
        ]
    )
    # overwrite the last AI message from the agent
    response.id = last_ai_message.id
    return {"messages": [response]}


builder = StateGraph(MessagesState)

builder.add_node("agent", call_model)
builder.add_node("tools", tool_node)
# add a separate final node
builder.add_node("final", call_final_model)

builder.add_edge(START, "agent")
builder.add_conditional_edges(
    "agent",
    should_continue,
)
builder.add_edge("tools", "agent")
builder.add_edge("final", END)

graph = builder.compile()


@cl.on_chat_start
async def on_chat_start():
    model = ChatOpenAI(streaming=True)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You're a very knowledgeable agent with access to several tools to get recent data and multiply numbers.",
            ),
            ("human", "{question}"),
        ]
    )
    runnable = prompt | model | StrOutputParser()
    cl.user_session.set("runnable", runnable)


@cl.on_message
async def on_message(msg: cl.Message):
    config = {}
    cb = cl.LangchainCallbackHandler()
    final_answer = cl.Message(content="")

    for msg, metadata in graph.stream(
        {"messages": [HumanMessage(content=msg.content)]},
        stream_mode="messages",
        config=RunnableConfig(callbacks=[cb], **config),
    ):
        # Only stream tokens emitted by the "final" node back to the user
        if (
            msg.content
            and not isinstance(msg, HumanMessage)
            and metadata["langgraph_node"] == "final"
        ):
            await final_answer.stream_token(msg.content)

    await final_answer.send()