import os
from typing import Literal

import chainlit as cl
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from langchain_core.runnables import RunnableConfig
from langchain_core.messages import HumanMessage
@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    if city == "nyc":
        return "It might be cloudy in nyc"
    elif city == "sf":
        return "It's always sunny in sf"
    else:
        raise AssertionError("Unknown city")
tools = [get_weather]

model = ChatOpenAI(
    base_url="https://albert.api.etalab.gouv.fr/v1",
    api_key=os.environ["API_KEY"],
    model_name="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    temperature=0,
)
final_model = ChatOpenAI(
    base_url="https://albert.api.etalab.gouv.fr/v1",
    api_key=os.environ["API_KEY"],
    model_name="mistralai/Mistral-Small-3.1-24B-Instruct-2503",
    temperature=0,
)

model = model.bind_tools(tools)
# NOTE: this is where we add a tag that we can use later to filter the model stream
# events to only the model called in the final node. This is not necessary if you call
# a single LLM, but it can matter if you call multiple models within the node and want
# to filter events from only one of them.
final_model = final_model.with_config(tags=["final_node"])
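
# For reference only, a minimal sketch (not used below) of filtering streamed events by
# this tag rather than by node name; it assumes the LangChain v2 event schema, where
# each event exposes its tags:
#
#   async for event in graph.astream_events(inputs, version="v2"):
#       if event["event"] == "on_chat_model_stream" and "final_node" in event.get("tags", []):
#           chunk = event["data"]["chunk"]
#           ...  # handle the streamed token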
tool_node = ToolNode(tools=tools)
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import MessagesState
from langchain_core.messages import SystemMessage
def should_continue(state: MessagesState) -> Literal["tools", "final"]:
    messages = state["messages"]
    last_message = messages[-1]
    # If the LLM made a tool call, route to the "tools" node
    if last_message.tool_calls:
        return "tools"
    # Otherwise, stop and reply to the user via the "final" node
    return "final"
def call_model(state: MessagesState):
    messages = state["messages"]
    response = model.invoke(messages)
    # We return a list, because this will get appended to the existing message list
    return {"messages": [response]}
def call_final_model(state: MessagesState):
    messages = state["messages"]
    last_ai_message = messages[-1]
    response = final_model.invoke(
        [
            SystemMessage("Rewrite this in the voice of Al Roker"),
            HumanMessage(last_ai_message.content),
        ]
    )
    # Overwrite the last AI message from the agent so the rewrite replaces it in state
    response.id = last_ai_message.id
    return {"messages": [response]}
builder = StateGraph(MessagesState)

builder.add_node("agent", call_model)
builder.add_node("tools", tool_node)
# Add a separate final node that rewrites the agent's answer
builder.add_node("final", call_final_model)

builder.add_edge(START, "agent")
builder.add_conditional_edges(
    "agent",
    should_continue,
)
builder.add_edge("tools", "agent")
builder.add_edge("final", END)

graph = builder.compile()
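
# Chainlit entry point: run the graph on each incoming user message and stream only
# the tokens produced by the "final" node back to the UI.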
@cl.on_message
async def on_message(msg: cl.Message):
    config = {"configurable": {"thread_id": cl.context.session.id}}
    cb = cl.LangchainCallbackHandler()
    final_answer = cl.Message(content="")

    for chunk, metadata in graph.stream(
        {"messages": [HumanMessage(content=msg.content)]},
        stream_mode="messages",
        config=RunnableConfig(callbacks=[cb], **config),
    ):
        # Only stream tokens emitted by the "final" node back to the user
        if (
            chunk.content
            and not isinstance(chunk, HumanMessage)
            and metadata["langgraph_node"] == "final"
        ):
            await final_answer.stream_token(chunk.content)

    await final_answer.send()
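
# To launch the app locally (assuming this file is saved as app.py and API_KEY is set
# in the environment):
#   chainlit run app.py -w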