# ai/human.py: human-in-the-loop promotional-email graph built on LangGraph
from typing import Annotated, Any, Literal
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.types import interrupt, Command
from typing_extensions import TypedDict
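# Runtime assumptions (not configured in this file): GROQ_API_KEY must be set
# in the environment for ChatGroq, and TAVILY_API_KEY for TavilySearchResults.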
"""
from langchain_anthropic import ChatAnthropic
from langchain_ollama.llms import OllamaLLM
from langchain_experimental.llms.ollama_functions import OllamaFunctions
llm = OllamaFunctions(model="qwen2.5", format="json")
llm_with_tools = llm #.bind_tools(tools)
"""
from langchain_groq import ChatGroq
llm = ChatGroq(
model="gemma2-9b-it", #"llama-3.1-8b-instant",
temperature=0.4,
max_tokens=None,
timeout=None,
max_retries=2,
# other params...
)
template = """Question: {question}
Answer: Let's think step by step."""
prompt = ChatPromptTemplate.from_template(template)
# model = OllamaLLM(model="deepseek-r1")
chain = prompt | llm
# print(chain.invoke({"question": "Explain capacity planning like I'm 5."}))
@tool
def human_assistance(query: str) -> str:
"""Request assistance from a human."""
human_response = interrupt({"query": query})
return human_response["data"]
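# How the round-trip works: `interrupt` pauses the run and ships `query` to
# the caller; whatever the caller sends back via Command(resume=...) becomes
# this tool's return value, so the resume payload must be a dict with a
# "data" key. Note this graph never actually executes the tool (no ToolNode
# is wired in); the sketch below assumes a hypothetical graph that does.
"""
config = {"configurable": {"thread_id": "assist-demo"}}
some_graph.invoke({"messages": [("user", "please ask a human")]}, config=config)  # pauses in the tool
some_graph.invoke(Command(resume={"data": "a human-written answer"}), config=config)
"""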
search_tool = TavilySearchResults(max_results=2)  # renamed so it no longer shadows the @tool decorator
tools = [search_tool, human_assistance]
llm_with_tools=llm.bind_tools(tools)
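# Quick sanity check that tool binding works (a sketch; 'tavily_search_results_json'
# is TavilySearchResults' default tool name, and the exact output depends on the model):
"""
msg = llm_with_tools.invoke("Search for current mortgage refinance rates")
print(msg.tool_calls)  # e.g. [{'name': 'tavily_search_results_json', 'args': {...}, ...}]
"""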
# llm = OllamaLLM(model="deepseek-r1") #ChatAnthropic(model="claude-3-5-sonnet-20240620")
class State(TypedDict):
messages: Annotated[list, add_messages]
persona: str
email: str
    release: Literal["approved", "rejected"]
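# `add_messages` is a reducer: node updates to `messages` are appended (and
# merged by message id) instead of overwriting the list. A standalone
# illustration, not executed as part of the graph:
"""
from langchain_core.messages import HumanMessage
merged = add_messages([HumanMessage("hi")], [HumanMessage("there")])
# merged now contains both messages
"""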
graph_builder = StateGraph(State)
def write_email(state: State):
prompt = f"""Write an promotional personalized email for this persona and offer financial education and setup a meeting for financial advisor, Only the email nothing else:
{state["persona"]}
"""
email = llm_with_tools.invoke(prompt)
# Because we will be interrupting during tool execution,
# we disable parallel tool calling to avoid repeating any
# tool invocations when we resume.
# assert len(email.tool_calls) <= 1
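    # With no `goto`, Command(update=...) simply merges this partial dict into
    # graph state; returning {"email": email.content} would behave the same here.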
return Command(update={"email": email.content})
graph_builder.add_node("write_email", write_email)
def delivery(state: State):
print(f"""Delivering: {state['email']}""")
return Command(update={"messages": ["Email delivered to customer"]})
graph_builder.add_node("delivery", delivery)
def human_approval(state: State) -> Command[Literal["delivery", END]]:
is_approved = interrupt(
"Approval for release the promotional email to customer? (type: approved or rejected):"
)
if is_approved == "approved":
return Command(goto="delivery", update={"release": "approved"})
else:
return Command(goto=END, update={"release": "rejected"})
# Add the node to the graph in an appropriate location
# and connect it to the relevant nodes.
graph_builder.add_node("human_approval", human_approval)
graph_builder.add_edge(START, "write_email")
graph_builder.add_edge("write_email", "human_approval")
graph_builder.add_edge("delivery", END)
checkpointer = MemorySaver()
graph = graph_builder.compile(checkpointer=checkpointer)
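# MemorySaver keeps checkpoints in process memory only, so paused threads are
# lost on restart; each thread_id below scopes its own checkpoint history,
# which is what allows a paused campaign to be resumed later.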
def email(persona, campaign, history):
    """Start a campaign thread; returns after the first event, leaving the
    graph waiting at the human_approval interrupt."""
    thread_config = {"configurable": {"thread_id": campaign}}
    for event in graph.stream({"persona": persona}, config=thread_config):
        for value in event.values():
            return "Assistant: ", value, "Value: ", graph.get_state(thread_config).values
def feedback(deliver, campaign, history):
    """Resume a paused campaign thread with the human decision
    ("approved" or "rejected")."""
    thread_config = {"configurable": {"thread_id": campaign}}
    for event in graph.stream(Command(resume=deliver), config=thread_config):
        for value in event.values():
            return "Assistant: ", value, "Value: ", graph.get_state(thread_config).values
'''
from IPython.display import Image, display
try:
display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
# This requires some extra dependencies and is optional
pass
'''
def campaign(user_input: Any, id: str):
    """Drive one campaign thread: pass {"persona": ...} to start, or
    Command(resume=...) to answer a pending interrupt."""
    thread_config = {"configurable": {"thread_id": id}}
for event in graph.stream(user_input, config=thread_config):
for value in event.values():
print("Assistant:", value, "Value: ", graph.get_state(thread_config).values)
"""
campaign({"persona": "My mortgage rate is 9%, I cannot afford it anymore, I need to refinance and I'm unemploy right now."}, "MOR")
campaign({"persona": "my credit card limit is too low, I need a card with bigger limit and low fee"}, "CARD")
campaign(Command(resume="approved"), "MOR")
"""
# Interactive smoke test; intentionally disabled (change to True to run locally).
while False:
try:
user_input = input("User: ")
if user_input.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
campaign(user_input, "MORT")
# stream_graph_updates(user_input)
except Exception as e:
# fallback if input() is not available
user_input = "What do you know about LangGraph?"
print("User: " + user_input)
campaign(user_input, "MORT")
# stream_graph_updates(user_input)
break