xiaosuhu86 committed on
Commit
058efce
·
1 Parent(s): f2c88dd

Update: Streamlit modification

Browse files
Files changed (3) hide show
  1. app.py +51 -43
  2. modules/graph.py +26 -0
  3. modules/st_callable_util.py +119 -0
app.py CHANGED
@@ -1,59 +1,67 @@
 
 
1
  import streamlit as st
2
- from modules.data_class import DataState
3
- from modules.tools import data_node
4
- from modules.nodes import chatbot_with_tools, human_node, maybe_exit_human_node, maybe_route_to_tools
5
 
6
- from langgraph.graph import StateGraph, START, END
 
7
 
8
- from IPython.display import Image, display
9
- from pprint import pprint
10
- from typing import Literal
11
 
12
- from langgraph.prebuilt import ToolNode
 
 
13
 
14
- from collections.abc import Iterable
15
- from IPython.display import display, clear_output
16
- import sys
17
 
 
 
 
18
 
19
- # Define the LangGraph chatbot
20
- graph_builder = StateGraph(DataState)
 
 
 
 
 
 
 
 
21
 
22
- # Add nodes
23
- graph_builder.add_node("chatbot_healthassistant", chatbot_with_tools)
24
- graph_builder.add_node("patient", human_node)
25
- graph_builder.add_node("documenting", data_node)
26
 
27
- # Define edges
28
- graph_builder.add_conditional_edges("chatbot_healthassistant", maybe_route_to_tools)
29
- graph_builder.add_conditional_edges("patient", maybe_exit_human_node)
30
- graph_builder.add_edge("documenting", "chatbot_healthassistant")
31
- graph_builder.add_edge(START, "chatbot_healthassistant")
32
 
33
- # Compile the graph
34
- graph_with_order_tools = graph_builder.compile()
 
35
 
36
- # Streamlit UI
37
- st.title("LangGraph Chatbot")
38
- st.markdown("Chat with an AI-powered health assistant.")
 
 
39
 
40
- # Initialize session state
41
  if "messages" not in st.session_state:
42
- st.session_state.messages = []
43
-
44
- user_input = st.text_input("You:", key="input")
45
 
46
- if st.button("Send"):
47
- if user_input:
48
- # Add user input to history
49
- st.session_state.messages.append(("user", user_input))
 
 
 
 
50
 
51
- # Run LangGraph chatbot
52
- state = DataState(messages=st.session_state.messages, data={}, finished=False)
53
- for output in graph_with_order_tools.stream(state):
54
- response = output["messages"][-1] # Get the last chatbot response
55
- st.session_state.messages.append(("ai", response))
56
 
57
- # Display chat history
58
- for sender, message in st.session_state.messages:
59
- st.write(f"**{sender}:** {message}")
 
 
 
 
1
import os

import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage

from modules.graph import invoke_our_graph
from modules.st_callable_util import get_streamlit_cb  # Utility function to get a Streamlit callback handler with context

# Streamlit UI
st.title("Paintrek Medical Assistant")
st.markdown("Chat with an AI-powered health assistant.")

# Initialize the expander state (open until the user sends their first message)
if "expander_open" not in st.session_state:
    st.session_state.expander_open = True

# Check if the Google API key is set (fix: the original comment said "OpenAI"
# although the code reads GOOGLE_API_KEY)
if not os.getenv('GOOGLE_API_KEY'):
    # If not, display a sidebar input for the user to provide the API key
    st.sidebar.header("GOOGLE_API_KEY Setup")
    api_key = st.sidebar.text_input(label="API Key", type="password", label_visibility="collapsed")
    # If no key is provided, show an info message and stop further execution
    # until a key is entered.
    if not api_key:
        st.info("Please enter your GOOGLE_API_KEY in the sidebar.")
        st.stop()
    # Export the key only after validating it is non-empty; the original wrote
    # an empty string into os.environ before the check, leaving GOOGLE_API_KEY
    # "set but empty" for anything reading the environment mid-run.
    os.environ["GOOGLE_API_KEY"] = api_key

# Capture user input from chat input
prompt = st.chat_input()

# Toggle expander state based on user input
if prompt is not None:
    st.session_state.expander_open = False  # Close the expander when the user starts typing

# st write magic: the bare string literal inside the expander is rendered by Streamlit
with st.expander(label="Paintrek Bot", expanded=st.session_state.expander_open):
    """
    At any time you can type 'q' or 'quit' to quit.
    """

# Initialize chat messages in session state with a welcome message
if "messages" not in st.session_state:
    st.session_state["messages"] = [AIMessage(content="Welcome to the Paintrek world. I am a health assistant, an interactive clinical recording system. I will ask you questions about your pain and related symptoms and record your responses. I will then store this information securely. At any time, you can type `q` to quit.")]

# Loop through all messages in the session state and render them as a chat on every rerun.
# https://docs.streamlit.io/develop/api-reference/chat/st.chat_message
# Messages are stored as AIMessage/HumanMessage as it is easier to send them to LangGraph.
for msg in st.session_state.messages:
    if isinstance(msg, AIMessage):
        st.chat_message("assistant").write(msg.content)
    elif isinstance(msg, HumanMessage):
        st.chat_message("user").write(msg.content)

# Handle user input if provided
if prompt:
    st.session_state.messages.append(HumanMessage(content=prompt))
    st.chat_message("user").write(prompt)

    with st.chat_message("assistant"):
        # Create a new placeholder for streaming messages and other events, and give it context
        st_callback = get_streamlit_cb(st.container())
        response = invoke_our_graph(st.session_state.messages, [st_callback])
        # Persist the assistant's final message and render it inside the chat_message context
        st.session_state.messages.append(AIMessage(content=response["messages"][-1].content))
        st.write(response["messages"][-1].content)
modules/graph.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from modules.data_class import DataState
from modules.tools import data_node
from modules.nodes import chatbot_with_tools, human_node, maybe_exit_human_node, maybe_route_to_tools
from langgraph.graph import StateGraph, START, END

# Assemble the LangGraph chatbot: three nodes wired together by conditional routing.
graph_builder = StateGraph(DataState)

# Register every node (name -> callable) in a single pass.
for _node_name, _node_fn in (
    ("chatbot_healthassistant", chatbot_with_tools),
    ("patient", human_node),
    ("documenting", data_node),
):
    graph_builder.add_node(_node_name, _node_fn)

# Routing: the assistant decides whether to call tools or hand off to the human,
# the human turn decides whether to exit, and documenting always returns control
# to the assistant. The graph starts at the assistant node.
graph_builder.add_conditional_edges("chatbot_healthassistant", maybe_route_to_tools)
graph_builder.add_conditional_edges("patient", maybe_exit_human_node)
graph_builder.add_edge("documenting", "chatbot_healthassistant")
graph_builder.add_edge(START, "chatbot_healthassistant")

# Compile the state graph into a runnable object
graph_with_order_tools = graph_builder.compile()


def invoke_our_graph(st_messages, callables):
    """Run the compiled graph once over the given chat history.

    Args:
        st_messages: Chat history (LangChain message objects) seeded into the
            graph state under the ``messages`` key.
        callables: List of LangChain callback handlers attached to the run.

    Returns:
        The final graph state produced by ``invoke``.

    Raises:
        TypeError: If ``callables`` is not a list.
    """
    if not isinstance(callables, list):
        raise TypeError("callables must be a list")
    return graph_with_order_tools.invoke({"messages": st_messages}, config={"callbacks": callables})
modules/st_callable_util.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypeVar, Any, Dict, Optional
2
+ import inspect
3
+
4
+ from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
5
+ from streamlit.delta_generator import DeltaGenerator
6
+
7
+ from langchain_core.callbacks.base import BaseCallbackHandler
8
+ import streamlit as st
9
+
10
+
11
# Define a function to create a callback handler for Streamlit that updates the UI dynamically
def get_streamlit_cb(parent_container: DeltaGenerator) -> BaseCallbackHandler:
    """
    Creates a Streamlit callback handler that updates the provided Streamlit container with new tokens.

    Every ``on_*`` callback method of the returned handler is wrapped so it runs
    inside the current Streamlit script-run context (via ``add_script_run_ctx``);
    the callbacks may fire from a worker thread where plain Streamlit calls
    would otherwise not be attached to the session.

    Args:
        parent_container (DeltaGenerator): The Streamlit container where the text will be rendered.
    Returns:
        BaseCallbackHandler: An instance of a callback handler configured for Streamlit.
    """

    # Define a custom callback handler class for managing and displaying stream events in Streamlit
    class StreamHandler(BaseCallbackHandler):
        """
        Custom callback handler for Streamlit that updates a Streamlit container with new tokens.
        """

        def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
            """
            Initializes the StreamHandler with a Streamlit container and optional initial text.
            Args:
                container (st.delta_generator.DeltaGenerator): The Streamlit container where text will be rendered.
                initial_text (str): Optional initial text to start with in the container.
            """
            self.container = container  # The Streamlit container to update
            self.thoughts_placeholder = self.container.container()  # container to hold tool_call renders
            self.tool_output_placeholder = None  # placeholder for the output of the tool call to be in the expander
            self.token_placeholder = self.container.empty()  # for token streaming
            self.text = initial_text  # The text content to display, starting with initial text

        def on_llm_new_token(self, token: str, **kwargs) -> None:
            """
            Callback method triggered when a new token is received (e.g., from a language model).
            Args:
                token (str): The new token received.
                **kwargs: Additional keyword arguments.
            """
            self.text += token  # Append the new token to the existing text
            self.token_placeholder.write(self.text)  # Re-render the accumulated text so far

        def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
            """
            Run when the tool starts running.
            Args:
                serialized (Dict[str, Any]): The serialized tool.
                input_str (str): The input string.
                kwargs (Any): Additional keyword arguments.
            """
            # NOTE(review): assumes the serialized tool dict always carries
            # "name" and "description" keys — a KeyError here would abort the
            # callback; verify against the tools registered with the graph.
            with self.thoughts_placeholder:
                status_placeholder = st.empty()  # Placeholder to show the tool's status
                with status_placeholder.status("Calling Tool...", expanded=True) as s:
                    st.write("called ", serialized["name"])  # Show which tool is being called
                    st.write("tool description: ", serialized["description"])
                    st.write("tool input: ")
                    st.code(input_str)  # Display the input data sent to the tool
                    st.write("tool output: ")
                    # Placeholder for tool output that will be updated later below
                    self.tool_output_placeholder = st.empty()
                    s.update(label="Completed Calling Tool!", expanded=False)  # Update the status once done

        def on_tool_end(self, output: Any, **kwargs: Any) -> Any:
            """
            Run when the tool ends.
            Args:
                output (Any): The output from the tool.
                kwargs (Any): Additional keyword arguments.
            """
            # We assume that `on_tool_end` comes after `on_tool_start`, meaning output_placeholder exists
            # NOTE(review): `output.content` presumes a message-like object
            # (e.g. ToolMessage) rather than a raw string — confirm with the
            # LangChain version in use.
            if self.tool_output_placeholder:
                self.tool_output_placeholder.code(output.content)  # Display the tool's output

    # Define a type variable for generic type hinting in the decorator, to maintain
    # input function and wrapped function return type
    fn_return_type = TypeVar('fn_return_type')

    # Decorator function to add the Streamlit execution context to a function
    def add_streamlit_context(fn: Callable[..., fn_return_type]) -> Callable[..., fn_return_type]:
        """
        Decorator to ensure that the decorated function runs within the Streamlit execution context.
        Args:
            fn (Callable[..., fn_return_type]): The function to be decorated.
        Returns:
            Callable[..., fn_return_type]: The decorated function that includes the Streamlit context setup.
        """
        ctx = get_script_run_ctx()  # Retrieve the current Streamlit script execution context
        # (captured once, at decoration time, on the main script thread)

        def wrapper(*args, **kwargs) -> fn_return_type:
            """
            Wrapper function that adds the Streamlit context and then calls the original function.
            Args:
                *args: Positional arguments to pass to the original function.
                **kwargs: Keyword arguments to pass to the original function.
            Returns:
                fn_return_type: The result from the original function.
            """
            add_script_run_ctx(ctx=ctx)  # Attach the captured context to the calling thread
            return fn(*args, **kwargs)  # Call the original function with its arguments

        return wrapper

    # Create an instance of the custom StreamHandler with the provided Streamlit container
    st_cb = StreamHandler(parent_container)

    # Iterate over all methods of the StreamHandler instance
    for method_name, method_func in inspect.getmembers(st_cb, predicate=inspect.ismethod):
        if method_name.startswith('on_'):  # Identify callback methods
            setattr(st_cb, method_name, add_streamlit_context(method_func))  # Wrap and replace the method

    # Return the fully configured StreamHandler instance with the context-aware callback methods
    return st_cb