File size: 10,230 Bytes
ca8935b
7d7cace
0692ff6
7d294f3
 
ee99cbb
 
 
 
0692ff6
e40b8c7
 
eb42ed9
7d294f3
e40b8c7
eb42ed9
3d8565e
ace10d3
341c82c
 
3d8565e
 
c3c4747
 
 
 
 
 
 
 
 
 
 
 
fb2c72e
eb42ed9
 
00e8820
eb42ed9
 
00e8820
538187a
 
 
00e8820
1dc15ce
 
381fb97
1dc15ce
 
 
 
 
 
 
 
c3c4747
de1ba67
6c07144
de1ba67
0692ff6
a526566
0692ff6
 
 
5f321b4
 
 
7d294f3
5d6578f
35b583e
 
 
 
 
 
 
 
 
 
 
 
 
848db73
c3c4747
 
5d6578f
35b583e
 
 
 
 
 
 
 
 
 
 
 
 
374569e
 
01d6892
5d6578f
35b583e
 
 
 
 
 
 
 
 
 
 
 
 
381fb97
 
01d6892
5d6578f
35b583e
 
 
 
 
 
 
 
 
 
 
 
 
5450f7d
c892be7
122cdab
55c2f3f
c892be7
ee99cbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d68882b
ee99cbb
c892be7
ccb54a0
 
 
 
19dad51
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import streamlit as st
import os
from streamlit_extras.stylable_container import stylable_container
from PIL import Image

from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser




# Variables used Globally
# NOTE: on HuggingFace Spaces, /data is the mount point for persistent storage;
# every uploaded file and the logo image are read from / written to this directory.
path = "/data"  #preset path for hugging face spaces for persistent storage and cannot be changed
# HuggingFace model repository used for all LLM inference calls below.
model_id="mistralai/Mistral-7B-Instruct-v0.3"

# Configure the Streamlit app: browser tab title/icon, page heading, and a short
# description that interpolates the model id so the UI always reflects the model in use.
st.set_page_config(page_title="CRIStine", page_icon = "πŸ‘©β€πŸ”§")
st.title("CRIStine - Interactive CRIS Assistant")
st.markdown(f"*This is a chatbot that uses the HuggingFace transformers with Retrieval Augmented Generation  to guide and train users. It uses the {model_id}.*")


# Application Functions
# File Loader
# File Loader
@st.dialog("Upload a File")
def upload_file():
    """Modal dialog: let the user pick a file and save it into persistent storage.

    The chosen file is written verbatim into the module-level `path` directory
    under its original name; its name and MIME type are echoed back to the user.
    """
    chosen = st.file_uploader("Choose a file")
    if chosen is None:
        # Nothing selected yet — the dialog simply waits for input.
        return
    st.write({"FileName": chosen.name, "FileType": chosen.type})
    destination = os.path.join(path, chosen.name)
    with open(destination, "wb") as out:
        out.write(chosen.getbuffer())
        st.success("Saved File")

# File Delete
@st.dialog("Delete a File")
def delete_file():
    # List all files in directory and subdirectories as buttons
    for root, dirs, file_names in os.walk(path):
        for file_name in file_names:
            if st.button(file_name):
                os.remove(os.path.join(path,file_name))
                st.success("Removed File")
                st.rerun()
       
# File View
# File View
@st.dialog("Files used by AI")
def view_file():
    """Modal dialog: show the names of every file currently in persistent storage."""
    # Flatten the recursive directory walk into one list of bare file names.
    names = [fname
             for _root, _dirs, fnames in os.walk(path)
             for fname in fnames]
    st.write(names)
    if st.button("Close"):
        st.rerun()


# Top toolbar: six equal columns — logo, a spacer, and four action buttons
# (upload / delete / view / recycle). Each button is decorated with a Font
# Awesome glyph injected via CSS on the button's text.
logo_column, space_column, upload_column, delete_column, browse_column, recycle_column  = st.columns(6)

# Load the Font Awesome stylesheet so the CSS `content: '\f...'` glyph codes
# used below resolve to icons.
st.markdown(
    '<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/css/all.min.css"/> ',
    unsafe_allow_html=True,
)

# App logo, loaded from the persistent-storage directory.
with logo_column:
    image = Image.open(os.path.join(path,'CRIStine.png'))
    st.image(image, caption='CRIStine')
    
# Upload button — opens the upload_file dialog. The stylable_container CSS
# prepends Font Awesome glyph \f574 (file-upload) to the button label.
with upload_column:
    with stylable_container(
        key="upload_button",
        css_styles=r"""
            button p:before {
                font-family: 'Font Awesome 5 Free';
                content: '\f574';
                display: inline-block;
                padding-right: 3px;
                vertical-align: middle;
                font-weight: 900;
            }
            """,
    ):
        if st.button("Upload",  key='upload'):
            upload_file()
 
# Delete button — opens the delete_file dialog (glyph \f1c3).
with delete_column:
    with stylable_container(
        key="delete_button",
        css_styles=r"""
            button p:before {
                font-family: 'Font Awesome 5 Free';
                content: '\f1c3';
                display: inline-block;
                padding-right: 3px;
                vertical-align: middle;
                font-weight: 900;
            }
            """,
    ):
        if st.button("Delete",  key='delete'):
            delete_file()
   
# View button — opens the view_file dialog (glyph \f07c, folder-open).
with browse_column:
    with stylable_container(
        key="view_button",
        css_styles=r"""
            button p:before {
                font-family: 'Font Awesome 5 Free';
                content: '\f07c';
                display: inline-block;
                padding-right: 3px;
                vertical-align: middle;
                font-weight: 900;
            }
            """,
    ):
        if st.button("View",  key='view'):
            view_file()
 
# Recycle button (glyph \f1b8) — NOTE(review): no click handler is attached,
# so this button currently does nothing beyond triggering a rerun.
with recycle_column:
    with stylable_container(
        key="recycle_button",
        css_styles=r"""
            button p:before {
                font-family: 'Font Awesome 5 Free';
                content: '\f1b8';
                display: inline-block;
                padding-right: 3px;
                vertical-align: middle;
                font-weight: 900;
            }
            """,
    ):
        st.button("Recycle",  key='recycle')

# Main app goes below here -
 


def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """
    Build a language model wrapper for HuggingFace inference.

    Parameters:
    - model_id (str): The HuggingFace model repository ID (defaults to the
      module-level model).
    - max_new_tokens (int): Cap on the number of new tokens generated per call.
    - temperature (float): Sampling temperature for generation.

    Returns:
    - HuggingFaceEndpoint: The configured language model endpoint.
    """
    # The API token is read from the environment (HF_TOKEN) on every call.
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        token=os.getenv("HF_TOKEN"),
    )

# Initialize session state for avatars
# Maps each chat role to the emoji avatar shown next to its messages;
# actual values are chosen in the sidebar selectboxes below.
if "avatars" not in st.session_state:
    st.session_state.avatars = {'user': None, 'assistant': None}

# Initialize session state for user text input
# Holds the most recent chat_input submission (None until the user types).
if 'user_text' not in st.session_state:
    st.session_state.user_text = None

# Initialize session state for model parameters
# NOTE(review): this default (256) differs from the sidebar widget's
# value=128, and the widget overwrites it on every rerun — confirm intent.
if "max_response_length" not in st.session_state:
    st.session_state.max_response_length = 256

# NOTE(review): these two defaults are immediately overwritten by the sidebar
# text areas (whose value= strings differ slightly) — confirm which wording
# is intended.
if "system_message" not in st.session_state:
    st.session_state.system_message = "friendly AI conversing with a human user"

if "starter_message" not in st.session_state:
    st.session_state.starter_message = "Hello, there! How can I help you today?"
    
    
# Sidebar for settings
# NOTE(review): each widget's return value is assigned straight into
# st.session_state on every rerun, so the widget's value= argument — not the
# initialization defaults above — is the effective default. Confirm this is
# intentional.
with st.sidebar:
    st.header("System Settings")

    # AI Settings: system prompt and the assistant's opening message.
    st.session_state.system_message = st.text_area(
        "System Message", value="You are a friendly AI conversing with a human user."
    )
    st.session_state.starter_message = st.text_area(
        'First AI Message', value="Hello, there! How can I help you today?"
    )

    # Model Settings: token budget passed to get_response as max_new_tokens.
    st.session_state.max_response_length = st.number_input(
        "Max Response Length", value=128
    )

    # Avatar Selection: emoji shown beside assistant and user messages.
    st.markdown("*Select Avatars:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars['assistant'] = st.selectbox(
            "AI Avatar", options=["πŸ€—", "πŸ’¬", "πŸ€–"], index=0
        )
    with col2:
        st.session_state.avatars['user'] = st.selectbox(
            "User Avatar", options=["πŸ‘€", "πŸ‘±β€β™‚οΈ", "πŸ‘¨πŸΎ", "πŸ‘©", "πŸ‘§πŸΎ"], index=0
        )
    # Reset Chat History
    reset_history = st.button("Reset Chat History")
    
# Initialize or reset chat history: seed the conversation with the
# configurable starter message as the assistant's first turn.
if "chat_history" not in st.session_state or reset_history:
    st.session_state.chat_history = [{"role": "assistant", "content": st.session_state.starter_message}]

def get_response(system_message, chat_history, user_text,
                 eos_token_id=None, max_new_tokens=256, get_llm_hf_kws=None):
    """
    Generates a response from the chatbot model.

    Args:
        system_message (str): The system message for the conversation.
        chat_history (list): The list of previous chat messages; mutated in
            place with the new user/assistant turns and also returned.
        user_text (str): The user's input text.
        eos_token_id (list, optional): End-of-sentence token IDs. Defaults to
            ['User']. Currently unused by the body; kept for interface
            compatibility.
        max_new_tokens (int, optional): The maximum number of new tokens to generate.
        get_llm_hf_kws (dict, optional): Additional keyword arguments forwarded
            to get_llm_hf_inference.

    Returns:
        tuple: A tuple containing the generated response and the updated chat history.
    """
    # Avoid the mutable-default-argument pitfall: never use [] / {} as
    # defaults — build fresh objects per call instead.
    if eos_token_id is None:
        eos_token_id = ['User']
    if get_llm_hf_kws is None:
        get_llm_hf_kws = {}

    # Set up the model.
    # BUG FIX: forward get_llm_hf_kws — previously it was accepted but ignored.
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1,
                              **get_llm_hf_kws)

    # Create the prompt template (Mistral-instruct [INST] ... [/INST] format).
    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nCurrent Conversation:\n{chat_history}\n\n"
            "\nUser: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )
    # Make the chain and bind the prompt
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')

    # Generate the response
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    # Keep only the text after the final "AI:" marker echoed by the model.
    response = response.split("AI:")[-1]

    # Update the chat history
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history

# Chat interface: a bordered container holding the message list above the
# chat input box. The input's value is stashed in session state so the
# rendering code below can react to it in the same rerun.
chat_interface = st.container(border=True)
with chat_interface:
    output_container = st.container()
    st.session_state.user_text = st.chat_input(placeholder="Enter your text here.")
    
# Display chat messages
with output_container:
    # For every message in the history
    for message in st.session_state.chat_history:
        # Skip the system message
        if message['role'] == 'system':
            continue
            
        # Display the chat message using the correct avatar
        with st.chat_message(message['role'], 
                             avatar=st.session_state['avatars'][message['role']]):
            st.markdown(message['content'])
            
    # When the user enters new text:
    if st.session_state.user_text:
        
        # Display the user's new message immediately
        with st.chat_message("user", 
                             avatar=st.session_state.avatars['user']):
            st.markdown(st.session_state.user_text)
            
        # Display a spinner status bar while waiting for the response
        with st.chat_message("assistant", 
                             avatar=st.session_state.avatars['assistant']):

            with st.spinner("Thinking..."):
                # Call the Inference API with the system_prompt, user text, and history
                response, st.session_state.chat_history = get_response(
                    system_message=st.session_state.system_message, 
                    user_text=st.session_state.user_text,
                    chat_history=st.session_state.chat_history,
                    max_new_tokens=st.session_state.max_response_length,
                )
                st.markdown(response)