Spaces:
Sleeping
Sleeping
File size: 10,230 Bytes
ca8935b 7d7cace 0692ff6 7d294f3 ee99cbb 0692ff6 e40b8c7 eb42ed9 7d294f3 e40b8c7 eb42ed9 3d8565e ace10d3 341c82c 3d8565e c3c4747 fb2c72e eb42ed9 00e8820 eb42ed9 00e8820 538187a 00e8820 1dc15ce 381fb97 1dc15ce c3c4747 de1ba67 6c07144 de1ba67 0692ff6 a526566 0692ff6 5f321b4 7d294f3 5d6578f 35b583e 848db73 c3c4747 5d6578f 35b583e 374569e 01d6892 5d6578f 35b583e 381fb97 01d6892 5d6578f 35b583e 5450f7d c892be7 122cdab 55c2f3f c892be7 ee99cbb d68882b ee99cbb c892be7 ccb54a0 19dad51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 |
import streamlit as st
import os
from streamlit_extras.stylable_container import stylable_container
from PIL import Image
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
# Variables used globally
# Preset persistent-storage path on Hugging Face Spaces; it cannot be changed.
path = "/data"
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# Configure the Streamlit app
st.set_page_config(
    page_title="CRIStine",
    # Woman-mechanic emoji (WOMAN + ZERO WIDTH JOINER + WRENCH). Written with
    # escapes so the icon survives a round trip through a non-UTF-8 editor or
    # viewer; the previous literal had been garbled by exactly such a
    # mis-decoding.
    page_icon="\U0001F469\u200D\U0001F527",
)
st.title("CRIStine - Interactive CRIS Assistant")
st.markdown(f"*This is a chatbot that uses the HuggingFace transformers with Retrieval Augmented Generation to guide and train users. It uses the {model_id}.*")
# Application Functions
# File Loader
@st.dialog("Upload a File")
def upload_file():
    """Show a dialog that stores one uploaded file in the storage directory.

    Nothing is written until the user has picked a file. The file is saved
    under the module-level ``path`` directory, and a success message is shown
    once the bytes are on disk.
    """
    uploaded_file = st.file_uploader("Choose a file")
    if uploaded_file is None:
        return

    file_details = {"FileName": uploaded_file.name, "FileType": uploaded_file.type}
    st.write(file_details)

    # The name is supplied by the client. Keep only its final component so a
    # crafted value such as "../../app.py" cannot write outside `path`.
    safe_name = os.path.basename(uploaded_file.name)
    if not safe_name:
        st.error("Invalid file name")
        return

    with open(os.path.join(path, safe_name), "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.success("Saved File")
# File Delete
@st.dialog("Delete a File")
def delete_file():
    """Show a dialog with one button per stored file; clicking it deletes the file.

    Every file under the module-level ``path`` directory, including those in
    subdirectories, is listed. After a deletion the app is rerun so the
    dialog closes and the listing is refreshed.
    """
    for root, _dirs, file_names in os.walk(path):
        for file_name in file_names:
            # Build the path from `root`, not `path`: os.walk descends into
            # subdirectories, and joining against the top-level directory
            # would point at a file that does not exist (or at a different
            # file that happens to share the name).
            file_path = os.path.join(root, file_name)
            # Key on the full path so two files with the same name in
            # different directories do not collide on widget identity.
            if st.button(file_name, key=f"delete::{file_path}"):
                os.remove(file_path)
                st.success("Removed File")
                st.rerun()
# File View
@st.dialog("Files used by AI")
def view_file():
    """Show a dialog listing the names of all files under the storage directory.

    Names are collected from the module-level ``path`` directory and all of
    its subdirectories. The "Close" button reruns the app.
    """
    stored_names = [
        name
        for _root, _dirs, names in os.walk(path)
        for name in names
    ]
    st.write(stored_names)

    close_clicked = st.button("Close")
    if close_clicked:
        st.rerun()
# Toolbar row: logo, a spacer, then one column per file-management button.
logo_column, space_column, upload_column, delete_column, browse_column, recycle_column = st.columns(6)

# Load Font Awesome so the toolbar buttons can show icon glyphs via CSS.
st.markdown(
    '<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/css/all.min.css"/> ',
    unsafe_allow_html=True,
)

with logo_column:
    logo_path = os.path.join(path, 'CRIStine.png')
    # The logo lives in the same storage directory that users upload to. If
    # it is missing, skip it instead of crashing the whole script before the
    # upload button has had a chance to render.
    if os.path.isfile(logo_path):
        image = Image.open(logo_path)
        st.image(image, caption='CRIStine')
    else:
        image = None
# CSS shared by every toolbar button: draws a Font Awesome glyph in front of
# the button label. `%s` is replaced with the glyph's escaped code point.
_ICON_BUTTON_CSS = r"""
    button p:before {
        font-family: 'Font Awesome 5 Free';
        content: '%s';
        display: inline-block;
        padding-right: 3px;
        vertical-align: middle;
        font-weight: 900;
    }
    """


def _toolbar_button(column, container_key, glyph, label, button_key, action=None):
    """Render one icon button inside ``column`` and run ``action`` when clicked.

    ``glyph`` is the CSS-escaped Font Awesome code point (e.g. ``\\f574``).
    ``action`` may be ``None`` for a button that has no handler.
    """
    with column:
        with stylable_container(key=container_key, css_styles=_ICON_BUTTON_CSS % glyph):
            clicked = st.button(label, key=button_key)
            if clicked and action is not None:
                action()


_toolbar_button(upload_column, "upload_button", r"\f574", "Upload", 'upload', upload_file)
_toolbar_button(delete_column, "delete_button", r"\f1c3", "Delete", 'delete', delete_file)
_toolbar_button(browse_column, "view_button", r"\f07c", "View", 'view', view_file)
# The recycle button is rendered but has no handler.
_toolbar_button(recycle_column, "recycle_button", r"\f1b8", "Recycle", 'recycle')
# Main app goes below here -
def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """
    Build a HuggingFace endpoint client for text generation.

    Parameters:
    - model_id (str): Repository ID of the HuggingFace model to call.
    - max_new_tokens (int): Upper bound on the number of generated tokens.
    - temperature (float): Sampling temperature passed to the endpoint.

    Returns:
    - HuggingFaceEndpoint: Endpoint configured with the arguments above and
      the token read from the ``HF_TOKEN`` environment variable.
    """
    hf_token = os.getenv("HF_TOKEN")
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        token=hf_token,
    )
# Seed session state on the first run only; keys that already exist are left
# untouched so values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    # Avatars shown next to chat messages
    "avatars": {'user': None, 'assistant': None},
    # Latest text submitted through the chat input
    "user_text": None,
    # Model parameters
    "max_response_length": 256,
    "system_message": "friendly AI conversing with a human user",
    "starter_message": "Hello, there! How can I help you today?",
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Sidebar for settings. Each widget's current value is written back into
# session state on every run.
with st.sidebar:
    st.header("System Settings")

    # AI settings
    st.session_state.system_message = st.text_area(
        "System Message", value="You are a friendly AI conversing with a human user."
    )
    st.session_state.starter_message = st.text_area(
        'First AI Message', value="Hello, there! How can I help you today?"
    )

    # Model settings. A response needs at least one token, so reject zero and
    # negative values in the widget instead of sending them to the endpoint.
    st.session_state.max_response_length = st.number_input(
        "Max Response Length", value=128, min_value=1
    )

    # Avatar selection. The emoji are written as escapes so they survive a
    # round trip through a non-UTF-8 editor or viewer; the previous literals
    # had been garbled by such a mis-decoding.
    st.markdown("*Select Avatars:*")
    col1, col2 = st.columns(2)
    with col1:
        # NOTE(review): the first and third options were indistinguishable in
        # the garbled text; hugging face / robot is the reconstruction -
        # confirm against the intended avatars.
        st.session_state.avatars['assistant'] = st.selectbox(
            "AI Avatar",
            options=[
                "\U0001F917",  # hugging face
                "\U0001F4AC",  # speech balloon
                "\U0001F916",  # robot
            ],
            index=0,
        )
    with col2:
        st.session_state.avatars['user'] = st.selectbox(
            "User Avatar",
            options=[
                "\U0001F464",                    # bust in silhouette
                "\U0001F471\u200D\u2642\uFE0F",  # man: blond hair
                "\U0001F468\U0001F3FE",          # man: medium-dark skin tone
                "\U0001F469",                    # woman
                "\U0001F467\U0001F3FE",          # girl: medium-dark skin tone
            ],
            index=0,
        )

    # Reset chat history
    reset_history = st.button("Reset Chat History")
# Start a fresh conversation on first load or when the reset button was
# pressed; the history then holds only the assistant's starter message.
_needs_fresh_history = reset_history or "chat_history" not in st.session_state
if _needs_fresh_history:
    _opening_message = {"role": "assistant", "content": st.session_state.starter_message}
    st.session_state.chat_history = [_opening_message]
def get_response(system_message, chat_history, user_text,
                 eos_token_id=None, max_new_tokens=256, get_llm_hf_kws=None):
    """
    Generates a response from the chatbot model.

    Args:
        system_message (str): The system message for the conversation.
        chat_history (list): The list of previous chat messages. It is
            extended in place with the new user and assistant messages.
        user_text (str): The user's input text.
        eos_token_id (list, optional): End-of-sequence markers. Currently
            unused by this function; defaults to ``['User']``.
        max_new_tokens (int, optional): The maximum number of new tokens to generate.
        get_llm_hf_kws (dict, optional): Additional keyword arguments for
            ``get_llm_hf_inference``. Entries here take precedence over
            ``max_new_tokens`` and the default temperature of 0.1.

    Returns:
        tuple: A tuple containing the generated response and the updated chat history.
    """
    # None stands in for the former mutable defaults ([] / {}), which would
    # otherwise be shared between calls.
    if eos_token_id is None:
        eos_token_id = ['User']
    llm_kwargs = {"max_new_tokens": max_new_tokens, "temperature": 0.1}
    llm_kwargs.update(get_llm_hf_kws or {})

    # Set up the model
    hf = get_llm_hf_inference(**llm_kwargs)

    # Create the prompt template
    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nCurrent Conversation:\n{chat_history}\n\n"
            "\nUser: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )

    # Make the chain and bind the prompt
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')

    # Generate the response
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    # Keep only the text after the last "AI:" marker, in case the model
    # echoed part of the prompt.
    response = response.split("AI:")[-1]

    # Update the chat history
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history
# Chat interface: a bordered container holding the message area and, below
# it, the chat input box.
chat_interface = st.container(border=True)
with chat_interface:
    output_container = st.container()
    st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")

with output_container:
    # Replay the stored conversation; system entries are never shown.
    for entry in st.session_state.chat_history:
        entry_role = entry['role']
        if entry_role != 'system':
            with st.chat_message(entry_role, avatar=st.session_state['avatars'][entry_role]):
                st.markdown(entry['content'])

    # Handle a newly submitted message, if any.
    if st.session_state.user_text:
        # Echo the user's message straight away.
        with st.chat_message("user", avatar=st.session_state.avatars['user']):
            st.markdown(st.session_state.user_text)

        # Show a spinner in the assistant bubble while the model is queried.
        with st.chat_message("assistant", avatar=st.session_state.avatars['assistant']):
            with st.spinner("Thinking..."):
                # Query the model with the system message, the new text and
                # the history so far; the history is updated by the call.
                response, st.session_state.chat_history = get_response(
                    system_message=st.session_state.system_message,
                    user_text=st.session_state.user_text,
                    chat_history=st.session_state.chat_history,
                    max_new_tokens=st.session_state.max_response_length,
                )
                st.markdown(response)
|