"""Tool-calling LangGraph agent demo.

Builds a two-node graph -- an ``agent`` node that calls a tool-bound
``gpt-4o-mini`` chat model, and an ``action`` node that executes any tool
calls the model emits -- then streams each node's state update for a
sample prompt.
"""
import asyncio
from typing import TypedDict, Annotated

from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage
# NOTE(review): unused in this script, and ``langchain.schema`` is a legacy
# import path; kept in case other code relies on it.
from langchain.schema.runnable import RunnableLambda

# tools
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.openweathermap.tool import OpenWeatherMapQueryRun


@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    return first_int * second_int


# NOTE: these tools are constructed at import time; presumably they read their
# API keys (e.g. Tavily, OpenWeatherMap) from the environment -- TODO confirm.
tavily_search = TavilySearchResults(max_results=5)
weather_query = OpenWeatherMapQueryRun()
arxiv_query = ArxivQueryRun()

tool_belt = [
    tavily_search,
    weather_query,
    arxiv_query,
    multiply,
]

# temperature=0 for repeatable runs; bind_tools advertises every tool in the
# belt so the model can emit structured tool calls.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
llm = llm.bind_tools(tool_belt)


class AgentState(TypedDict):
    """Graph state shared by every node."""

    # Conversation history; the ``add_messages`` reducer appends each node's
    # returned messages instead of overwriting the list.
    messages: Annotated[list, add_messages]
    # NOTE(review): declared but never read or written by any node here.
    loop_count: int


async def call_model(state: AgentState) -> dict:
    """Agent node: send the conversation so far to the tool-bound model.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update ``{"messages": [reply]}``; the reducer on
        ``AgentState.messages`` appends the reply to the history.
    """
    # Use the async API: a synchronous ``invoke`` inside this coroutine would
    # block the event loop for the whole network round-trip.
    response = await llm.ainvoke(state["messages"])
    return {"messages": [response]}


def should_continue(state: AgentState) -> str:
    """Route after the agent node.

    Returns:
        ``"action"`` when the latest message requested tool calls, otherwise
        ``END`` to finish the run.
    """
    last_message = state["messages"][-1]
    # getattr: only AI messages carry ``tool_calls``; any other message type
    # ends the run instead of raising AttributeError.
    if getattr(last_message, "tool_calls", None):
        return "action"
    return END


tool_node = ToolNode(tool_belt)

graph = StateGraph(AgentState)
graph.add_node("agent", call_model)
graph.add_node("action", tool_node)
graph.set_entry_point("agent")
# should_continue picks the next hop: "action" (run tools) or END.
graph.add_conditional_edges("agent", should_continue)
# After tools run, hand their results back to the model.
graph.add_edge("action", "agent")

tool_call_graph = graph.compile()

DEFAULT_QUESTION = (
    "Search Arxiv for the QLoRA paper, then search each of the authors to find "
    "out their latest Tweet using Tavily! and solve 5 x 5 please."
)


async def main(question: str = DEFAULT_QUESTION) -> None:
    """Run the agent on *question*, printing each node's message update.

    Args:
        question: User prompt sent as the initial ``HumanMessage``; defaults
            to the demo QLoRA / 5 x 5 prompt.
    """
    inputs = {"messages": [HumanMessage(content=question)]}
    # With stream_mode="updates", each chunk maps node name -> that node's
    # state update; only the update's messages are printed.
    async for chunk in tool_call_graph.astream(inputs, stream_mode="updates"):
        for values in chunk.values():
            print(values["messages"])
            print("\n")


if __name__ == "__main__":
    asyncio.run(main())