import streamlit as st
from itertools import tee
from model import InferenceBuilder
# pip.main(['install', './apex-0.1-cp311-cp311-linux_x86_64.whl']) # install the apex package from wheel since building takes forever
# import apex
MODEL_AVATAR_URL = "./iphone_robot.png"
MAX_CHAT_TURNS = 10 # limit this for preliminary testing
MSG_MAX_TURNS_EXCEEDED = f"Sorry! The CyberSolve LinAlg playground is limited to {MAX_CHAT_TURNS} turns in a single history. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
EXAMPLE_PROMPTS = [
    "Solve 24 = 1601c - 1605c for c.",
    "Solve 657 = -220*t + 1086*t + 22307 for t.",
    "Solve -11*y - 263*y + 3162 = -88*y for y.",
    "Solve 0 = -11*b - 4148 + 4225 for b.",
    "Solve 65*l - 361 + 881 = 0 for l.",
    "Solve 49*l + 45*l - 125 - 63 = 0 for l.",
]
TITLE = "CyberSolve LinAlg 1.2"
DESCRIPTION= """Welcome to the πŸ€–CyberSolve LinAlg 1.2🧠 demo! \n
**Overview and Usage**: This πŸ€— Space is designed to demo the abilities of the **CyberSolve LinAlg 1.2** text-to-text language model. Model card: *MarioBarbeque/CyberSolve-LinAlg-1.2*
Specifically, the **CyberSolve LinAlg 1.x** family of models
are downstream versions of the 783M parameter FLAN-T5 text-to-text transformer, fine-tuned on the Google DeepMind Mathematics dataset for the purpose of solving linear equations of a single variable.
To effectively query the model for its intended task, prompt the model to solve an arbitrary linear equation of a single variable with a query of the form: *"Solve 24 = 1601c - 1605c for c."*; the model
will return its prediciton in a simple format. The algebraic capabailites of CyberSolve far exceed those of the base FLAN-T5 model. CyberSolve LinAlg 1.2 achieves a 90.7 percent exact match benchmark
on the DeepMind Mathematics evaluation dataset of 10,000 unique linear equations; the FLAN-T5 base model scores 9.6 percent.
On the left is a sidebar of **Examples** that can be clicked to query to model.
**Feedback**: Feedback is welcomed, encouraged, and invaluable! To give feedback in regards to one of the model's responses, click the **Give Feedback on Last Response** button just below
the user input bar. This allows you to provide either positive or negative feedback in regards to the model's most recent response. A **Feedback Form** will appear above the model's title.
Please be sure to select either πŸ‘ or πŸ‘Ž before adding additional notes about your choice. Be as brief or as detailed as you like! Note that you are making a difference; this
feedback allows us to later improve this model for your usage through a training technique known as reinforcement learning through human feedback. \n
Please provide any additional, larger feedback, ideas, or issues to the email: **[email protected]**. Happy inference!"""
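# The description above tells users to prompt with e.g. "Solve 24 = 1601c - 1605c for c.". As a
# hedged illustration (not executed by this app, and assuming the Hub checkpoint name taken from
# the model card referenced above), the same query could be run directly against the checkpoint:
#
#   from transformers import AutoTokenizer, T5ForConditionalGeneration
#   tok = AutoTokenizer.from_pretrained("MarioBarbeque/CyberSolve-LinAlg-1.2")
#   mdl = T5ForConditionalGeneration.from_pretrained("MarioBarbeque/CyberSolve-LinAlg-1.2")
#   ids = tok("Solve 24 = 1601c - 1605c for c.", return_tensors="pt").input_ids
#   print(tok.decode(mdl.generate(ids)[0], skip_special_tokens=True))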
GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
# # To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1 # test this number
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
QUEUE_SIZE = 20 # maximize this value for adding enough places in the global queue?
# if QUEUE_SIZE_ENV is not None:
#     QUEUE_SIZE = int(QUEUE_SIZE_ENV)
# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()
st.set_page_config(layout="wide")
# url = "https://huggingface.co/MarioBarbeque/CyberSolve-LinAlg-1.2"
st.title(TITLE)
st.image("calabi_yau.jpeg", caption="Teaching AI to understand Mathematics", width=400) # TODO add a Vanderbilt related picture to the head of our Space!
# st.markdown(DESCRIPTION % url)
st.markdown(DESCRIPTION)
st.markdown("\n")
# use this to format later
with open("./style.css") as css:
    st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "feedback" not in st.session_state:
st.session_state["feedback"] = [None]
def clear_chat_history():
    st.session_state["messages"] = []
st.button('Clear Chat', on_click=clear_chat_history)
# build the model and tokenizer outside the working body so they're only instantiated once - simply pass the chat history for completion
builder = InferenceBuilder()
tokenizer = builder.load_tokenizer()
model = builder.load_model()
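# model.py is not shown here; a minimal sketch of the interface this app assumes from
# InferenceBuilder (the class body below is a guess for illustration, not the repo's actual code):
#
#   class InferenceBuilder:
#       def load_tokenizer(self):
#           from transformers import AutoTokenizer
#           return AutoTokenizer.from_pretrained("MarioBarbeque/CyberSolve-LinAlg-1.2")
#
#       def load_model(self):
#           from transformers import T5ForConditionalGeneration
#           return T5ForConditionalGeneration.from_pretrained("MarioBarbeque/CyberSolve-LinAlg-1.2").to("cuda")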
def last_role_is_user():
    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"

def get_last_question():
    return st.session_state["messages"][-1]["content"]
def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]
def get_stream_warning_error(stream):
    error = None
    warning = None
    for chunk in stream:
        if chunk["error"] is not None:
            error = chunk["error"]
        if chunk["warning"] is not None:
            warning = chunk["warning"]
    return warning, error
# # @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
# def chain_call(history):
#     input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
#     chat_completion = chain.stream(input)
#     return chat_completion
def model_inference(messages):
    input_ids = tokenizer(get_last_question(), return_tensors="pt").input_ids.to("cuda")  # tokenize the input and put it on the GPU
    # input_ids = tokenizer(get_last_question(), return_tensors="pt").input_ids  # testing on CPU
    outputs = model.generate(input_ids)
    for chunk in tokenizer.decode(outputs[0], skip_special_tokens=True):
        yield chunk  # yield each chunk of the predicted string character by character
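# A hedged alternative sketch (not used by the app): stream decoder tokens as they are generated
# using transformers' TextIteratorStreamer, instead of decoding the finished output and yielding
# it character by character as model_inference does above.
def model_inference_streaming(messages):
    from threading import Thread
    from transformers import TextIteratorStreamer
    inputs = tokenizer(get_last_question(), return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    # run generation in a background thread so we can iterate over the streamer as tokens arrive
    Thread(target=model.generate, kwargs={**inputs, "streamer": streamer}).start()
    for text in streamer:
        yield text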
def write_response():
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None
    return response, stream_warning, stream_error
def chat_completion(messages):
    # each completed turn is one user message plus one assistant reply; refuse once the limit is reached
    if (len(messages) - 1) // 2 >= MAX_CHAT_TURNS:
        yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
        return
    chat_completion = None
    error = None
    # *** TODO add code for implementing a global queue with a bounded semaphore?
    # wait to be in queue
    # with global_semaphore:
    #     try:
    #         chat_completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e
    # chat_completion = chain_call(history_dbrx_format)
    chat_completion = model_inference(messages)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return
    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        if chunk is not None:
            chunk_counter += 1
            partial_message += chunk
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
        # if chunk.choices[0].finish_reason == "length":
        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
    yield {"content": partial_message, "error": None, "warning": max_token_warning}
# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = [None, None, None]
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user", avatar="🧑‍💻"):
                st.markdown(user_input)
            # stream = chat_completion(st.session_state["messages"])
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
def feedback():
    with st.form("feedback_form"):
        st.title("Feedback Form")
        st.markdown("Please select either 👍 or 👎 before providing a reason for your review of the most recent response. Don't forget to click submit!")
        rating = st.feedback()
        feedback = st.text_input("Please detail your feedback: ")
        # implement a method for writing these responses to storage!
        submitted = st.form_submit_button("Submit Feedback")
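# Hedged sketch for the storage TODO above (not wired into the form): one simple option is to
# append each submission to a local JSONL file; the filename and record schema here are assumptions.
def save_feedback_locally(rating, notes, path="feedback_log.jsonl"):
    import json
    from datetime import datetime, timezone
    record = {
        "time": datetime.now(timezone.utc).isoformat(),
        "rating": rating,  # with st.feedback()'s default "thumbs" options, 0 is 👎 and 1 is 👍
        "notes": notes,
        "last_response": st.session_state["messages"][-1]["content"] if st.session_state["messages"] else None,
    }
    with open(path, "a") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")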
main = st.container()
with main:
    if st.session_state["feedback"][-1] is not None:  # TODO clean this up in a fn?
        st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "🧑‍💻"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                if message["error"] is not None:
                    st.error(message["error"], icon="🚨")
                if message["warning"] is not None:
                    st.warning(message["warning"], icon="⚠️")
    if prompt := st.chat_input("Type a message!", max_chars=5000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iphone users
    gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
    if gave_feedback:  # TODO clean up the conditions here with a function
        st.session_state["feedback"].append("given")
    else:
        st.session_state["feedback"].append(None)
with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)