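"""Chainlit app wrapping a LangGraph agent.

The agent answers weather questions with a small tool, then a separate
"final" node rewrites the answer in the voice of Al Roker. Both LLM calls go
through the Albert (Etalab) OpenAI-compatible API.
"""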
import os
from typing import Literal

import chainlit as cl
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode

@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    if city == "nyc":
        return "It might be cloudy in nyc"
    elif city == "sf":
        return "It's always sunny in sf"
    else:
        raise AssertionError("Unknown city")
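
# The @tool decorator exposes get_weather as a LangChain tool: its name,
# description (the docstring), and argument schema (the Literal city values)
# are what the model sees when deciding whether to call it.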
tools = [get_weather]

model = ChatOpenAI(
    base_url="https://albert.api.etalab.gouv.fr/v1",
    api_key=os.environ["API_KEY"],
    model_name="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    temperature=0,
)
final_model = ChatOpenAI(
    base_url="https://albert.api.etalab.gouv.fr/v1",
    api_key=os.environ["API_KEY"],
    model_name="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    temperature=0,
)
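# Two instances against the same Albert endpoint: `model` drives the agent
# loop (tools are bound to it below), while `final_model` only rewrites the
# final answer. API_KEY must be set in the environment.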
model = model.bind_tools(tools)
# NOTE: this is where we add a tag that we'll use later to filter the model
# stream events down to the model called in the final node. This isn't
# necessary if you call a single LLM, but it matters if you call multiple
# models within the node and want to filter events from only one of them.
final_model = final_model.with_config(tags=["final_node"])
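
# ToolNode executes whatever tool calls appear in the last AI message and
# appends the resulting tool messages to the state.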
tool_node = ToolNode(tools=tools)

def should_continue(state: MessagesState) -> Literal["tools", "final"]:
    messages = state["messages"]
    last_message = messages[-1]
    # If the LLM made a tool call, route to the "tools" node
    if last_message.tool_calls:
        return "tools"
    # Otherwise route to the final node, which replies to the user
    return "final"

def call_model(state: MessagesState):
    messages = state["messages"]
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

def call_final_model(state: MessagesState):
    messages = state["messages"]
    last_ai_message = messages[-1]
    response = final_model.invoke(
        [
            SystemMessage("Rewrite this in the voice of Al Roker"),
            HumanMessage(last_ai_message.content),
        ]
    )
    # Reuse the agent's message id so the add_messages reducer overwrites the
    # last AI message instead of appending a new one
    response.id = last_ai_message.id
    return {"messages": [response]}
builder = StateGraph(MessagesState)
builder.add_node("agent", call_model)
builder.add_node("tools", tool_node)
# add a separate final node
builder.add_node("final", call_final_model)
builder.add_edge(START, "agent")
builder.add_conditional_edges(
"agent",
should_continue,
)
builder.add_edge("tools", "agent")
builder.add_edge("final", END)
graph = builder.compile()
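# NOTE: compiled without a checkpointer, so the per-session thread_id passed
# below doesn't persist any state between turns. A sketch, if you wanted
# multi-turn memory (MemorySaver is langgraph's in-memory checkpointer):
#     from langgraph.checkpoint.memory import MemorySaver
#     graph = builder.compile(checkpointer=MemorySaver())
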
@cl.on_message
async def on_message(msg: cl.Message):
    config = {"configurable": {"thread_id": cl.context.session.id}}
    cb = cl.LangchainCallbackHandler()
    final_answer = cl.Message(content="")

    # Use `chunk` rather than reusing `msg`, which would shadow the handler's
    # parameter inside the loop
    for chunk, metadata in graph.stream(
        {"messages": [HumanMessage(content=msg.content)]},
        stream_mode="messages",
        config=RunnableConfig(callbacks=[cb], **config),
    ):
        # Only forward tokens produced by the "final" node, skipping echoed
        # human messages and intermediate agent/tool output
        if (
            chunk.content
            and not isinstance(chunk, HumanMessage)
            and metadata["langgraph_node"] == "final"
        ):
            await final_answer.stream_token(chunk.content)

    await final_answer.send()
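
# A minimal sketch for exercising the graph outside Chainlit (assumes API_KEY
# is set; the sample question is illustrative). The guard keeps it from
# running when Chainlit imports this module.
if __name__ == "__main__":
    for chunk, metadata in graph.stream(
        {"messages": [HumanMessage(content="what is the weather in sf?")]},
        stream_mode="messages",
    ):
        if chunk.content and metadata["langgraph_node"] == "final":
            print(chunk.content, end="", flush=True)
    print()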