Spaces:
Sleeping
Sleeping
Commit
·
058efce
1
Parent(s):
f2c88dd
Update: Streamlit modification
Browse files- app.py +51 -43
- modules/graph.py +26 -0
- modules/st_callable_util.py +119 -0
app.py
CHANGED
@@ -1,59 +1,67 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
-
from
|
3 |
-
from modules.tools import data_node
|
4 |
-
from modules.nodes import chatbot_with_tools, human_node, maybe_exit_human_node, maybe_route_to_tools
|
5 |
|
6 |
-
from
|
|
|
7 |
|
8 |
-
from IPython.display import Image, display
|
9 |
-
from pprint import pprint
|
10 |
-
from typing import Literal
|
11 |
|
12 |
-
|
|
|
|
|
13 |
|
14 |
-
from collections.abc import Iterable
|
15 |
-
from IPython.display import display, clear_output
|
16 |
-
import sys
|
17 |
|
|
|
|
|
|
|
18 |
|
19 |
-
#
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
-
# Add nodes
|
23 |
-
graph_builder.add_node("chatbot_healthassistant", chatbot_with_tools)
|
24 |
-
graph_builder.add_node("patient", human_node)
|
25 |
-
graph_builder.add_node("documenting", data_node)
|
26 |
|
27 |
-
#
|
28 |
-
|
29 |
-
graph_builder.add_conditional_edges("patient", maybe_exit_human_node)
|
30 |
-
graph_builder.add_edge("documenting", "chatbot_healthassistant")
|
31 |
-
graph_builder.add_edge(START, "chatbot_healthassistant")
|
32 |
|
33 |
-
#
|
34 |
-
|
|
|
35 |
|
36 |
-
#
|
37 |
-
st.
|
38 |
-
|
|
|
|
|
39 |
|
40 |
-
# Initialize session state
|
41 |
if "messages" not in st.session_state:
|
42 |
-
st.session_state
|
43 |
-
|
44 |
-
user_input = st.text_input("You:", key="input")
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
st.session_state.messages.append(("ai", response))
|
56 |
|
57 |
-
|
58 |
-
for
|
59 |
-
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
import streamlit as st
|
4 |
+
from langchain_core.messages import AIMessage, HumanMessage
|
|
|
|
|
5 |
|
6 |
+
from modules.graph import invoke_our_graph
|
7 |
+
from modules.st_callable_util import get_streamlit_cb # Utility function to get a Streamlit callback handler with context
|
8 |
|
|
|
|
|
|
|
9 |
|
10 |
+
# Streamlit UI
|
11 |
+
st.title("Paintrek Medical Assistant")
|
12 |
+
st.markdown("Chat with an AI-powered health assistant.")
|
13 |
|
|
|
|
|
|
|
14 |
|
15 |
+
# Initialize the expander state
|
16 |
+
if "expander_open" not in st.session_state:
|
17 |
+
st.session_state.expander_open = True
|
18 |
|
19 |
+
# Check if the OpenAI API key is set
|
20 |
+
if not os.getenv('GOOGLE_API_KEY'):
|
21 |
+
# If not, display a sidebar input for the user to provide the API key
|
22 |
+
st.sidebar.header("GOOGLE_API_KEY Setup")
|
23 |
+
api_key = st.sidebar.text_input(label="API Key", type="password", label_visibility="collapsed")
|
24 |
+
os.environ["GOOGLE_API_KEY"] = api_key
|
25 |
+
# If no key is provided, show an info message and stop further execution and wait till key is entered
|
26 |
+
if not api_key:
|
27 |
+
st.info("Please enter your GOOGLE_API_KEY in the sidebar.")
|
28 |
+
st.stop()
|
29 |
|
|
|
|
|
|
|
|
|
30 |
|
31 |
+
# Capture user input from chat input
|
32 |
+
prompt = st.chat_input()
|
|
|
|
|
|
|
33 |
|
34 |
+
# Toggle expander state based on user input
|
35 |
+
if prompt is not None:
|
36 |
+
st.session_state.expander_open = False # Close the expander when the user starts typing
|
37 |
|
38 |
+
# st write magic
|
39 |
+
with st.expander(label="Paintrek Bot", expanded=st.session_state.expander_open):
|
40 |
+
"""
|
41 |
+
At any time you can type 'q' or 'quit' to quit.
|
42 |
+
"""
|
43 |
|
44 |
+
# Initialize chat messages in session state
|
45 |
if "messages" not in st.session_state:
|
46 |
+
st.session_state["messages"] = [AIMessage(content="Welcome to the Paintrek world. I am a health assistant, an interactive clinical recording system. I will ask you questions about your pain and related symptoms and record your responses. I will then store this information securely. At any time, you can type `q` to quit.")]
|
|
|
|
|
47 |
|
48 |
+
# Loop through all messages in the session state and render them as a chat on every st.refresh mech
|
49 |
+
for msg in st.session_state.messages:
|
50 |
+
# https://docs.streamlit.io/develop/api-reference/chat/st.chat_message
|
51 |
+
# we store them as AIMessage and HumanMessage as its easier to send to LangGraph
|
52 |
+
if isinstance(msg, AIMessage):
|
53 |
+
st.chat_message("assistant").write(msg.content)
|
54 |
+
elif isinstance(msg, HumanMessage):
|
55 |
+
st.chat_message("user").write(msg.content)
|
56 |
|
57 |
+
# Handle user input if provided
|
58 |
+
if prompt:
|
59 |
+
st.session_state.messages.append(HumanMessage(content=prompt))
|
60 |
+
st.chat_message("user").write(prompt)
|
|
|
61 |
|
62 |
+
with st.chat_message("assistant"):
|
63 |
+
# create a new placeholder for streaming messages and other events, and give it context
|
64 |
+
st_callback = get_streamlit_cb(st.container())
|
65 |
+
response = invoke_our_graph(st.session_state.messages, [st_callback])
|
66 |
+
st.session_state.messages.append(AIMessage(content=response["messages"][-1].content)) # Add that last message to the st_message_state
|
67 |
+
st.write(response["messages"][-1].content) # Write the message inside the chat_message context
|
modules/graph.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from modules.data_class import DataState
|
2 |
+
from modules.tools import data_node
|
3 |
+
from modules.nodes import chatbot_with_tools, human_node, maybe_exit_human_node, maybe_route_to_tools
|
4 |
+
from langgraph.graph import StateGraph, START, END
|
5 |
+
|
6 |
+
# Define the LangGraph chatbot
|
7 |
+
graph_builder = StateGraph(DataState)
|
8 |
+
|
9 |
+
# Add nodes
|
10 |
+
graph_builder.add_node("chatbot_healthassistant", chatbot_with_tools)
|
11 |
+
graph_builder.add_node("patient", human_node)
|
12 |
+
graph_builder.add_node("documenting", data_node)
|
13 |
+
|
14 |
+
# Define edges
|
15 |
+
graph_builder.add_conditional_edges("chatbot_healthassistant", maybe_route_to_tools)
|
16 |
+
graph_builder.add_conditional_edges("patient", maybe_exit_human_node)
|
17 |
+
graph_builder.add_edge("documenting", "chatbot_healthassistant")
|
18 |
+
graph_builder.add_edge(START, "chatbot_healthassistant")
|
19 |
+
|
20 |
+
# Compile the state graph into a runnable object
|
21 |
+
graph_with_order_tools = graph_builder.compile()
|
22 |
+
|
23 |
+
def invoke_our_graph(st_messages, callables):
|
24 |
+
if not isinstance(callables, list):
|
25 |
+
raise TypeError("callables must be a list")
|
26 |
+
return graph_with_order_tools.invoke({"messages": st_messages}, config={"callbacks": callables})
|
modules/st_callable_util.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, TypeVar, Any, Dict, Optional
|
2 |
+
import inspect
|
3 |
+
|
4 |
+
from streamlit.runtime.scriptrunner import add_script_run_ctx, get_script_run_ctx
|
5 |
+
from streamlit.delta_generator import DeltaGenerator
|
6 |
+
|
7 |
+
from langchain_core.callbacks.base import BaseCallbackHandler
|
8 |
+
import streamlit as st
|
9 |
+
|
10 |
+
|
11 |
+
# Define a function to create a callback handler for Streamlit that updates the UI dynamically
|
12 |
+
def get_streamlit_cb(parent_container: DeltaGenerator) -> BaseCallbackHandler:
|
13 |
+
"""
|
14 |
+
Creates a Streamlit callback handler that updates the provided Streamlit container with new tokens.
|
15 |
+
Args:
|
16 |
+
parent_container (DeltaGenerator): The Streamlit container where the text will be rendered.
|
17 |
+
Returns:
|
18 |
+
BaseCallbackHandler: An instance of a callback handler configured for Streamlit.
|
19 |
+
"""
|
20 |
+
|
21 |
+
# Define a custom callback handler class for managing and displaying stream events in Streamlit
|
22 |
+
class StreamHandler(BaseCallbackHandler):
|
23 |
+
"""
|
24 |
+
Custom callback handler for Streamlit that updates a Streamlit container with new tokens.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
|
28 |
+
"""
|
29 |
+
Initializes the StreamHandler with a Streamlit container and optional initial text.
|
30 |
+
Args:
|
31 |
+
container (st.delta_generator.DeltaGenerator): The Streamlit container where text will be rendered.
|
32 |
+
initial_text (str): Optional initial text to start with in the container.
|
33 |
+
"""
|
34 |
+
self.container = container # The Streamlit container to update
|
35 |
+
self.thoughts_placeholder = self.container.container() # container to hold tool_call renders
|
36 |
+
self.tool_output_placeholder = None # placeholder for the output of the tool call to be in the expander
|
37 |
+
self.token_placeholder = self.container.empty() # for token streaming
|
38 |
+
self.text = initial_text # The text content to display, starting with initial text
|
39 |
+
|
40 |
+
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
41 |
+
"""
|
42 |
+
Callback method triggered when a new token is received (e.g., from a language model).
|
43 |
+
Args:
|
44 |
+
token (str): The new token received.
|
45 |
+
**kwargs: Additional keyword arguments.
|
46 |
+
"""
|
47 |
+
self.text += token # Append the new token to the existing text
|
48 |
+
self.token_placeholder.write(self.text)
|
49 |
+
|
50 |
+
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> None:
|
51 |
+
"""
|
52 |
+
Run when the tool starts running.
|
53 |
+
Args:
|
54 |
+
serialized (Dict[str, Any]): The serialized tool.
|
55 |
+
input_str (str): The input string.
|
56 |
+
kwargs (Any): Additional keyword arguments.
|
57 |
+
"""
|
58 |
+
with self.thoughts_placeholder:
|
59 |
+
status_placeholder = st.empty() # Placeholder to show the tool's status
|
60 |
+
with status_placeholder.status("Calling Tool...", expanded=True) as s:
|
61 |
+
st.write("called ", serialized["name"]) # Show which tool is being called
|
62 |
+
st.write("tool description: ", serialized["description"])
|
63 |
+
st.write("tool input: ")
|
64 |
+
st.code(input_str) # Display the input data sent to the tool
|
65 |
+
st.write("tool output: ")
|
66 |
+
# Placeholder for tool output that will be updated later below
|
67 |
+
self.tool_output_placeholder = st.empty()
|
68 |
+
s.update(label="Completed Calling Tool!", expanded=False) # Update the status once done
|
69 |
+
|
70 |
+
def on_tool_end(self, output: Any, **kwargs: Any) -> Any:
|
71 |
+
"""
|
72 |
+
Run when the tool ends.
|
73 |
+
Args:
|
74 |
+
output (Any): The output from the tool.
|
75 |
+
kwargs (Any): Additional keyword arguments.
|
76 |
+
"""
|
77 |
+
# We assume that `on_tool_end` comes after `on_tool_start`, meaning output_placeholder exists
|
78 |
+
if self.tool_output_placeholder:
|
79 |
+
self.tool_output_placeholder.code(output.content) # Display the tool's output
|
80 |
+
|
81 |
+
# Define a type variable for generic type hinting in the decorator, to maintain
|
82 |
+
# input function and wrapped function return type
|
83 |
+
fn_return_type = TypeVar('fn_return_type')
|
84 |
+
|
85 |
+
# Decorator function to add the Streamlit execution context to a function
|
86 |
+
def add_streamlit_context(fn: Callable[..., fn_return_type]) -> Callable[..., fn_return_type]:
|
87 |
+
"""
|
88 |
+
Decorator to ensure that the decorated function runs within the Streamlit execution context.
|
89 |
+
Args:
|
90 |
+
fn (Callable[..., fn_return_type]): The function to be decorated.
|
91 |
+
Returns:
|
92 |
+
Callable[..., fn_return_type]: The decorated function that includes the Streamlit context setup.
|
93 |
+
"""
|
94 |
+
ctx = get_script_run_ctx() # Retrieve the current Streamlit script execution context
|
95 |
+
|
96 |
+
def wrapper(*args, **kwargs) -> fn_return_type:
|
97 |
+
"""
|
98 |
+
Wrapper function that adds the Streamlit context and then calls the original function.
|
99 |
+
Args:
|
100 |
+
*args: Positional arguments to pass to the original function.
|
101 |
+
**kwargs: Keyword arguments to pass to the original function.
|
102 |
+
Returns:
|
103 |
+
fn_return_type: The result from the original function.
|
104 |
+
"""
|
105 |
+
add_script_run_ctx(ctx=ctx) # Add the Streamlit context to the current execution
|
106 |
+
return fn(*args, **kwargs) # Call the original function with its arguments
|
107 |
+
|
108 |
+
return wrapper
|
109 |
+
|
110 |
+
# Create an instance of the custom StreamHandler with the provided Streamlit container
|
111 |
+
st_cb = StreamHandler(parent_container)
|
112 |
+
|
113 |
+
# Iterate over all methods of the StreamHandler instance
|
114 |
+
for method_name, method_func in inspect.getmembers(st_cb, predicate=inspect.ismethod):
|
115 |
+
if method_name.startswith('on_'): # Identify callback methods
|
116 |
+
setattr(st_cb, method_name, add_streamlit_context(method_func)) # Wrap and replace the method
|
117 |
+
|
118 |
+
# Return the fully configured StreamHandler instance with the context-aware callback methods
|
119 |
+
return st_cb
|