|
import asyncio |
|
from typing import TypedDict, Annotated |
|
from langchain_openai import ChatOpenAI |
|
from langgraph.prebuilt import ToolNode |
|
from langgraph.graph import StateGraph, END |
|
from langgraph.graph.message import add_messages |
|
from langchain_core.messages import HumanMessage |
|
from langchain.schema.runnable import RunnableLambda |
|
|
|
|
|
from langchain_core.tools import tool |
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
from langchain_community.tools.arxiv.tool import ArxivQueryRun |
|
from langchain_community.tools.openweathermap.tool import OpenWeatherMapQueryRun |
|
|
|
@tool
def multiply(first_int: int, second_int: int) -> int:
    """Multiply two integers together."""
    # The docstring above doubles as the tool description shown to the
    # model, so it is kept verbatim.
    product = first_int * second_int
    return product
|
|
|
# Instantiate the third-party tools the agent may call.  Each client reads
# its API key from the environment (TAVILY_API_KEY, OPENWEATHERMAP_API_KEY).
tavily_search = TavilySearchResults(max_results=5)  # web search, top 5 hits

weather_query = OpenWeatherMapQueryRun()  # current-weather lookup

arxiv_query = ArxivQueryRun()  # arXiv paper search


# Full tool set: advertised to the model via bind_tools below and executed
# by the ToolNode further down — the two must stay in sync.
tool_belt = [

    tavily_search,

    weather_query,

    arxiv_query,

    multiply,

]


llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Rebind so every call carries the tool schemas.  NOTE(review): this
# intentionally shadows the bare model object under the same name.
llm = llm.bind_tools(tool_belt)
|
|
|
class AgentState(TypedDict):
    """State carried between graph nodes."""

    # Conversation history.  The add_messages reducer appends/merges the
    # messages a node returns instead of overwriting the whole list.
    messages: Annotated[list, add_messages]

    # NOTE(review): declared but never read or written by any visible node —
    # presumably intended for limiting agent/tool loops; confirm or remove.
    loop_count: int
|
|
|
async def call_model(state):
    """Agent node: run the tool-bound model on the conversation so far.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update — the model's reply wrapped in a list so the
        ``add_messages`` reducer appends it to the history.
    """
    messages = state["messages"]
    # Bug fix: the original called the synchronous ``invoke`` from an async
    # node, blocking the event loop for the whole network round trip.  Use
    # the Runnable's async API instead.
    response = await llm.ainvoke(messages)
    return {"messages": [response]}
|
|
|
def should_continue(state):
    """Route after the agent node.

    Returns "action" when the newest message requests tool calls, otherwise
    END to finish the run.
    """
    newest = state["messages"][-1]
    return "action" if newest.tool_calls else END
|
|
|
# Node that executes whichever tool calls the agent's latest message made.
tool_node = ToolNode(tool_belt)


# agent -> (action -> agent)* -> END
graph = StateGraph(AgentState)

graph.add_node("agent", call_model)  # LLM turn

graph.add_node("action", tool_node)  # tool-execution turn


graph.set_entry_point("agent")

# No path map: should_continue's return value ("action" or END) is used
# directly as the next node name.
graph.add_conditional_edges(

    "agent",

    should_continue

)

# Tool results always flow back to the agent for another turn.
graph.add_edge("action", "agent")


tool_call_graph = graph.compile()
|
|
|
async def main():
    """Drive the compiled graph on a demo prompt, printing every node update."""
    prompt = HumanMessage(content="Search Arxiv for the QLoRA paper, then search each of the authors to find out their latest Tweet using Tavily! and solve 5 x 5 please.")
    inputs = {"messages": [prompt]}

    # stream_mode="updates" yields one dict per step: {node_name: partial_state}.
    async for update in tool_call_graph.astream(inputs, stream_mode="updates"):
        for node_name, node_output in update.items():
            print(node_output["messages"])
            print('\n')
|
|
|
|
|
# Run the demo only when executed as a script — the original unconditional
# call would also fire on import of this module.
if __name__ == "__main__":
    asyncio.run(main())