# xiaosuhu86's picture
# update: streamlit version-still working on
# 18793f5
import pprint
import random
from langchain_core.tools import tool
from modules.data_class import DataState
from langgraph.prebuilt import InjectedState
from langchain_core.messages.tool import ToolMessage
# These functions have no body: LangGraph does not allow @tool functions to
# update the conversation state, so state updates are handled in a separate
# node. Using @tool is still very convenient for defining the tool schema, so
# empty functions are defined here and bound to the LLM, while their
# implementation is deferred to data_node.
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["ID"] under the same key.
# NOTE(review): LangChain's @tool uses the docstring as the model-facing tool
# description, so rewording it changes what the LLM sees.
@tool
def patient_id(name: str, DOB: str, gender: str, contact: str, emergency_contact: str) -> str:
    """Collecting basic patient identification information including:
    - Basic information (name, DOB, gender, contact details)
    - Emergency contact information
    Returns:
        The updated data with the patient ID information added.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["symptom"] under the same key.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def symptom(main_symptom: str, symptom_length: str) -> str:
    """Collecting patient's main symptom assessment including:
    - Primary symptoms
    - Duration of the symptoms
    Returns:
        The updated data with the patient's symptom information added.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["pain"] under the same key.
# NOTE(review): the docstring describes a 0-10 intensity "for each location",
# but the schema takes a single int and nothing here validates the range.
@tool
def pain(pain_location: str, pain_side: str, pain_intensity: int, pain_description: str, start_time: str, radiation: bool, triggers: str, symptom: str) -> str:
    """Collecting patient's pain status including:
    - Pain location using body mapping (head, arms, hands, trunk, legs, feet)
    - Pain side (left or right)
    - Pain intensity (0-10 scale for each location)
    - Pain characteristics and patterns
    - Onset time
    - Radiation patterns
    - Triggering factors
    - Associated symptoms
    Returns:
        The updated data with the patient's pain status added.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["medical_hist"] under the same key.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def medical_hist(medical_condition: str, first_time: str, surgery_history: str, medication: str, allergy: str) -> str:
    """Collecting patient's medical history including:
    - Existing medical conditions
    - First occurrence date
    - Surgical history with dates
    - Current medications
    - Allergies
    Returns:
        The updated data with the patient's medical history added.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and stores the argument in data["family_hist"]["family_history"].
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def family_hist(family_history: str) -> str:
    """Collecting patient's family history
    Returns:
        The updated data with the patient's family history added.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["social_hist"] under the same key.
# smoke/alcohol/drug are plain yes/no flags (bool); no amounts are captured.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def social_hist(occupation: str, smoke: bool, alcohol: bool, drug: bool, support_system: str, living_condition: str) -> str:
    """Collecting patient's social history including:
    - Occupation
    - smoking or not
    - alcohol use or not
    - drug use or not
    - living conditions
    - support system
    Returns:
        The updated data with the patient's social history added.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["review_system"] under the same key.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def review_system(weight_change: str, fever: bool, chill: bool, night_sweats: bool, sleep: str, gastrointestinal: str, urinary: str) -> str:
    """Collecting patient's review information including:
    - Recent weight changes
    - Constitutional symptoms (fever, chills, night sweats)
    - Sleep patterns
    - Gastrointestinal and urinary function
    Returns:
        The updated data with the patient's review.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["pain_manage"] under the same key.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def pain_manage(pain_medication: str, specialist: bool, other_therapy: str, effectiveness: bool) -> str:
    """Collecting patient's pain management including:
    - Current pain medications
    - Specialist consultations
    - Alternative therapies
    - Treatment effectiveness
    Returns:
        The updated data with the patient's pain management.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["functional"] under the same key.
# The docstring doubles as the model-facing tool description, so keep it accurate.
@tool
def functional(life_quality: str, limit_activity: str, mood: str) -> str:
    """Collecting patient's functional assessment information including:
    - Impact on quality of life
    - Activity limitations
    - Mood and emotional state
    Returns:
        The updated data with the patient's functional assessment information.
    """
# Schema-only stub: @tool turns this signature and docstring into the tool
# schema bound to the LLM. The stub itself is not executed; data_node handles
# the call and copies each argument into data["plan"] under the same key.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def plan(goal: str, expectation: str, alternative_treatment: str) -> str:
    """Collecting patient's future treatment plan information including:
    - Treatment goals
    - Patient expectations
    - Alternative treatment considerations
    Returns:
        The updated data with the patient's future treatment plan information.
    """
# Schema-only stub with no arguments. When the LLM calls it, data_node prints
# the collected data to stdout and sends the patient's console reply (read
# with input()) back as the tool result.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def confirm_data() -> str:
    """Asks the patient if the data intake is correct.
    Returns:
        The user's free-text response.
    """
# Schema-only stub with no arguments. When the LLM calls it, data_node replies
# with the collected data, or "(no data)" when nothing has been collected yet.
# The docstring doubles as the model-facing tool description, so keep it accurate.
@tool
def get_data() -> str:
    """Returns the user's data collected so far. One item per line."""
# Schema-only stub with no arguments. When the LLM calls it, data_node empties
# the collected data in place (data.clear()).
# The docstring doubles as the model-facing tool description; it previously
# spoke of "the user's order", which does not match this patient-intake tool.
@tool
def clear_data() -> str:
    """Removes all patient data collected so far."""
# Schema-only stub with no arguments. When the LLM calls it, data_node prints
# the collected data and sets the state's "finished" flag to True.
# NOTE(review): no actual database write is visible in this file despite the
# docstring; confirm where persistence is meant to happen.
# NOTE(review): the docstring doubles as the model-facing tool description.
@tool
def save_data() -> int:
    """Send the data into database.
    Returns:
        The status of data saving, finished.
    """
# Data-collecting tools: tool name -> (section key in ``data``, argument names).
# Each listed argument is copied verbatim into data[section] under the same key.
_TOOL_SECTIONS = {
    "patient_id": ("ID", ("name", "DOB", "gender", "contact", "emergency_contact")),
    "symptom": ("symptom", ("main_symptom", "symptom_length")),
    "pain": ("pain", (
        "pain_location", "pain_side", "pain_intensity", "pain_description",
        "start_time", "radiation", "triggers", "symptom",
    )),
    "medical_hist": ("medical_hist", (
        "medical_condition", "first_time", "surgery_history", "medication", "allergy",
    )),
    "family_hist": ("family_hist", ("family_history",)),
    "social_hist": ("social_hist", (
        "occupation", "smoke", "alcohol", "drug", "support_system", "living_condition",
    )),
    "review_system": ("review_system", (
        "weight_change", "fever", "chill", "night_sweats", "sleep",
        "gastrointestinal", "urinary",
    )),
    "pain_manage": ("pain_manage", (
        "pain_medication", "specialist", "other_therapy", "effectiveness",
    )),
    "functional": ("functional", ("life_quality", "limit_activity", "mood")),
    "plan": ("plan", ("goal", "expectation", "alternative_treatment")),
}


def _format_data(data: dict) -> str:
    """Render collected data as text, one ``section.field: value`` line per item.

    Returns "(no data)" when nothing has been collected yet.
    """
    lines = []
    for section, fields in data.items():
        if isinstance(fields, dict):
            lines.extend(f"{section}.{key}: {value}" for key, value in fields.items())
        else:
            lines.append(f"{section}: {fields}")
    return "\n".join(lines) if lines else "(no data)"


def data_node(state: DataState) -> DataState:
    """Execute the data-intake tool calls carried by the latest message.

    The @tool functions bound to the LLM are schema-only stubs; this node is
    where their effects happen. For every tool call in the last message it
    updates ``data`` (a dict of per-section dicts) and emits a matching
    ToolMessage with the tool's result for the LLM.

    Args:
        state: Graph state. Reads ``messages`` (the last one must carry
            ``tool_calls``) and ``data``.

    Returns:
        A partial state update with the new tool messages, the updated
        ``data`` and ``finished`` (True once ``save_data`` has been called).

    Raises:
        NotImplementedError: if a tool call names an unknown tool.
        KeyError: if a data-collecting tool call lacks one of its arguments.
    """
    tool_msg = state.get("messages", [])[-1]
    # ``data`` is indexed by section name below, so it must be a dict; the
    # previous ``[]`` default could never be indexed that way.
    data = state.get("data") or {}
    outbound_msgs = []
    data_saved = False

    for tool_call in tool_msg.tool_calls:
        name = tool_call["name"]
        args = tool_call["args"]

        if name in _TOOL_SECTIONS:
            section, fields = _TOOL_SECTIONS[name]
            # Create the section on first use instead of raising KeyError.
            record = data.setdefault(section, {})
            for field in fields:
                record[field] = args[field]
            # The tool docstrings promise the updated data, so send the actual
            # values (joining the dict directly would only list section names).
            response = _format_data(data)
        elif name == "confirm_data":
            # Show the exact stored data instead of letting the LLM paraphrase
            # it, so what the patient confirms is precisely what gets saved.
            print("Your input data:")
            if data:
                pprint.pprint(data)
            else:
                print(" (no items)")
            # NOTE: input() blocks on the process's stdin.
            response = input("Is this correct? ")
        elif name == "get_data":
            response = _format_data(data)
        elif name == "clear_data":
            data.clear()
            response = "All collected data has been cleared."
        elif name == "save_data":
            print("Saving the data!")
            print(data)
            # TODO: persist ``data`` to the database; nothing is stored yet.
            data_saved = True
            response = "Data saved; status: finished."
        else:
            raise NotImplementedError(f'Unknown tool call: {tool_call["name"]}')

        # Record the tool result so the LLM sees the outcome of its call.
        outbound_msgs.append(
            ToolMessage(
                content=response,
                name=name,
                tool_call_id=tool_call["id"],
            )
        )

    return {"messages": outbound_msgs, "data": data, "finished": data_saved}