|
from typing import Literal

import chainlit as cl
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools.yahoo_finance_news import YahooFinanceNewsTool
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode
|
|
@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    return first_int * second_int
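
# Quick sanity check (illustrative): @tool-decorated functions take a dict of
# their arguments, e.g.
#   multiply.invoke({"first_int": 6, "second_int": 7})  # -> 42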
|
|
|
# Instantiate the prebuilt tools. TavilySearchResults expects the
# TAVILY_API_KEY environment variable to be set.
yahoo_finance_news = YahooFinanceNewsTool()
tavily_search = TavilySearchResults(max_results=5)
arxiv_query = ArxivQueryRun()

tools = [
    yahoo_finance_news,
    tavily_search,
    arxiv_query,
    multiply,
]
|
|
# The agent model does the tool calling; a second model produces the final,
# user-facing summary.
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
final_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

model = model.bind_tools(tools)
# Tag the final model so its runs are identifiable when streaming events.
final_model = final_model.with_config(tags=["final_node"])
tool_node = ToolNode(tools=tools)
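
# Illustrative note: ToolNode runs every tool_call on the latest AIMessage and
# appends one ToolMessage per result, roughly (hypothetical message variable):
#   tool_node.invoke({"messages": [ai_message_with_tool_calls]})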
|
|
|
def should_continue(state: MessagesState) -> Literal["tools", "final"]:
    """Route to the tool node if the last message requests tool calls."""
    last_message = state["messages"][-1]
    if last_message.tool_calls:
        return "tools"
    return "final"
|
|
def call_model(state: MessagesState):
    """Call the tool-bound agent model on the conversation so far."""
    response = model.invoke(state["messages"])
    return {"messages": [response]}
|
|
def call_final_model(state: MessagesState):
    """Summarize the agent's last answer as point-form notes."""
    last_ai_message = state["messages"][-1]
    response = final_model.invoke(
        [
            SystemMessage("Provide a summary in point form notes of the following:"),
            HumanMessage(last_ai_message.content),
        ]
    )
    # Reuse the original ID so the add_messages reducer replaces the agent's
    # answer with the summary rather than appending a second AI message.
    response.id = last_ai_message.id
    return {"messages": [response]}
|
|
builder = StateGraph(MessagesState)

builder.add_node("agent", call_model)
builder.add_node("tools", tool_node)
builder.add_node("final", call_final_model)

builder.add_edge(START, "agent")
# should_continue returns "tools" or "final", which double as node names here.
builder.add_conditional_edges(
    "agent",
    should_continue,
)
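
# Equivalent, more explicit wiring (a sketch): pass a path map so routing does
# not depend on the return values matching node names.
#   builder.add_conditional_edges(
#       "agent", should_continue, {"tools": "tools", "final": "final"}
#   )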
|
|
|
builder.add_edge("tools", "agent") |
|
builder.add_edge("final", END) |
|
|
|
graph = builder.compile() |
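
# Optional smoke test, a minimal sketch: run this file directly with Python
# (requires OPENAI_API_KEY). Chainlit imports the module, so the guard keeps
# this from executing under `chainlit run`.
if __name__ == "__main__":
    result = graph.invoke({"messages": [HumanMessage(content="What is 6 times 7?")]})
    print(result["messages"][-1].content)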
|
|
|
@cl.on_chat_start
async def on_chat_start():
    # Note: this session runnable is stored but unused; on_message answers via
    # the compiled graph above. A local name avoids shadowing the module-level
    # tool-bound `model`.
    chat_model = ChatOpenAI(streaming=True)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You're a very knowledgeable agent with access to several tools to get recent data and multiply numbers.",
            ),
            ("human", "{question}"),
        ]
    )
    runnable = prompt | chat_model | StrOutputParser()
    cl.user_session.set("runnable", runnable)
|
|
@cl.on_message
async def on_message(msg: cl.Message):
    cb = cl.LangchainCallbackHandler()
    final_answer = cl.Message(content="")

    # Stream tokens from the graph; the loop variable is named `chunk` so it
    # does not shadow the incoming Chainlit `msg`. Only chunks emitted by the
    # "final" node are forwarded, so the user sees just the summary.
    for chunk, metadata in graph.stream(
        {"messages": [HumanMessage(content=msg.content)]},
        stream_mode="messages",
        config=RunnableConfig(callbacks=[cb]),
    ):
        if (
            chunk.content
            and not isinstance(chunk, HumanMessage)
            and metadata["langgraph_node"] == "final"
        ):
            await final_answer.stream_token(chunk.content)

    await final_answer.send()
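
# Launch with Chainlit (the filename is illustrative):
#   chainlit run app.py -w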
|
|