import functools
import operator
from typing import Annotated, Literal, Sequence

from IPython.display import Image, display
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    ToolMessage,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict
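
# The agents below rely on two external services; this module assumes the
# corresponding credentials are already present in the environment
# (OPENAI_API_KEY is read by ChatOpenAI, TAVILY_API_KEY by TavilySearchResults).
# A minimal, optional startup check (illustrative sketch only):
import os

_missing = [k for k in ("OPENAI_API_KEY", "TAVILY_API_KEY") if k not in os.environ]
if _missing:
    print(f"Warning: missing environment variables: {', '.join(_missing)}")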


def create_agent(llm, tools, system_message: str):
    """Create an agent."""
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK, another assistant with different tools "
                " will help where you left off. Execute what you can to make progress."
                " If you or any of the other assistants have the final answer or deliverable,"
                " prefix your response with FINAL ANSWER so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    prompt = prompt.partial(system_message=system_message)
    prompt = prompt.partial(tool_names=", ".join([tool.name for tool in tools]))
    return prompt | llm.bind_tools(tools)
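
# Illustrative only: create_agent returns a prompt | model runnable, so a
# hypothetical standalone call (names here are examples, not part of this module)
# would look like
#   researcher = create_agent(ChatOpenAI(model="gpt-4o"), [some_tool], "Be concise.")
#   researcher.invoke({"messages": [HumanMessage(content="Find the data")]})
# In this module that invocation happens inside agent_node via agent.invoke(state).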


# The PythonREPL instance must live at module scope because the python_repl
# tool below references it when it executes code.
repl = PythonREPL()


@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    try:
        result = repl.run(code)
    except BaseException as e:
        return f"Failed to execute. Error: {repr(e)}"
    result_str = f"Successfully executed:\n```python\n{code}\n```\nStdout: {result}"
    return (
        result_str + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."
    )


# This defines the object that is passed between each node
# in the graph. We will create different nodes for each agent and tool.
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    sender: str


# Helper function to create a node for a given agent
def agent_node(state, agent, name):
    result = agent.invoke(state)
    # Convert the agent output into a format that is suitable to append to the global state
    if isinstance(result, ToolMessage):
        pass
    else:
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
    return {
        "messages": [result],
        # Since we have a strict workflow, we can
        # track the sender so we know who to pass to next.
        "sender": name,
    }


def router(state) -> Literal["call_tool", "__end__", "continue"]:
    # This is the router
    messages = state["messages"]
    last_message = messages[-1]
    if last_message.tool_calls:
        # The previous agent is invoking a tool
        return "call_tool"
    if "FINAL ANSWER" in last_message.content:
        # Any agent decided the work is done
        return "__end__"
    return "continue"


def run_multi_agent(prompt):
    tavily_tool = TavilySearchResults(max_results=5)
    llm = ChatOpenAI(model="gpt-4o")

    # Research agent and node
    research_agent = create_agent(
        llm,
        [tavily_tool],
        system_message="You should provide accurate data for the chart_generator to use.",
    )
    research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

    # chart_generator
    chart_agent = create_agent(
        llm,
        [python_repl],
        system_message="Any charts you display will be visible by the user.",
    )
    chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")

    tools = [tavily_tool, python_repl]
    tool_node = ToolNode(tools)

    workflow = StateGraph(AgentState)
    workflow.add_node("Researcher", research_node)
    workflow.add_node("chart_generator", chart_node)
    workflow.add_node("call_tool", tool_node)

    workflow.add_conditional_edges(
        "Researcher",
        router,
        {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
    )
    workflow.add_conditional_edges(
        "chart_generator",
        router,
        {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
    )
    workflow.add_conditional_edges(
        "call_tool",
        # Each agent node updates the 'sender' field;
        # the tool calling node does not, meaning
        # this edge will route back to the original agent
        # who invoked the tool.
        lambda x: x["sender"],
        {
            "Researcher": "Researcher",
            "chart_generator": "chart_generator",
        },
    )
    workflow.set_entry_point("Researcher")
    graph = workflow.compile()

    try:
        display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
    except Exception:
        # This requires some extra dependencies and is optional
        pass

    events = graph.stream(
        {
            "messages": [HumanMessage(content=prompt)],
        },
        # Maximum number of steps to take in the graph
        {"recursion_limit": 150},
    )
    for s in events:
        print(s)
        print("----")
    return "DONE"