import streamlit as st

import genparam
from ibm_watsonx_ai import Credentials, APIClient
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from secretsload import load_stsecrets
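
# Expected .streamlit/secrets.toml keys, inferred from the st.secrets reads in
# this file (placeholder values, not real credentials):
#
#   app_password = "..."
#   api_key = "..."
#   url = "https://us-south.ml.cloud.ibm.com"  # example watsonx.ai endpoint URL
#   project_id = "..."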
credentials = load_stsecrets()  # load Streamlit secrets once at startup

st.set_page_config(
    page_title="Jimmy",
    page_icon="😒",
    initial_sidebar_state="collapsed"
)


def check_password():
    """Gate the app behind the password stored in st.secrets["app_password"]."""

    def password_entered():
        if st.session_state["password"] == st.secrets["app_password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # don't keep the plaintext password in session state
        else:
            st.session_state["password_correct"] = False

    if st.session_state.get("password_correct", False):
        return True

    st.markdown("\n\n")
    st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
    st.divider()
    st.info("Developed by Milan Mrdenovic © IBM Norway 2024")
    if "password_correct" in st.session_state:
        # A previous attempt was made and failed.
        st.error("😕 Password incorrect")
    return False


if not check_password():
    st.stop()


if 'current_page' not in st.session_state:
    st.session_state.current_page = 0  # note: current_page is initialized but not used elsewhere in this file


def initialize_session_state():
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []


def setup_client():
    """Create a watsonx.ai API client scoped to the configured project."""
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    return APIClient(credentials, project_id=st.secrets["project_id"])
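

# genparam is a local configuration module. Judging from the attributes this
# file reads, a minimal version might look like the sketch below (illustrative
# values only; the real module may differ):
#
#   TYPE = "chat"                    # "chat" folds history into each prompt
#   SYSTEM_PROMPT = "You are Jimmy, a helpful assistant."
#   PROMPT_TEMPLATE = "llama3-instruct (llama-3 & 3.1) - system"
#   BAKE_IN_PROMPT_SYNTAX = True
#   SELECTED_MODEL = "meta-llama/llama-3-1-8b-instruct"  # e.g., any watsonx.ai model id
#   VERIFY = False
#   DECODING_METHOD = "greedy"
#   MAX_NEW_TOKENS = 500
#   MIN_NEW_TOKENS = 1
#   REPETITION_PENALTY = 1.0
#   STOP_SEQUENCES = ["<|eot_id|>"]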


def prepare_prompt(prompt, chat_history):
    # In "chat" mode, fold the running history into a single prompt string.
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
        return f"Conversation History:\n{chats}\n\nNew Message: {prompt}"
    return prompt
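
# For example, with chat_history = [{"role": "user", "content": "hi"}], the
# prepared prompt for a new message "how are you?" would be:
#   Conversation History:
#   user: "hi"
#
#   New Message: how are you?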


def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    # Chat-template wrappers for the supported model families. The "- system"
    # variants expect a system prompt; the "- user" variants do not.
    model_family_syntax = {
        "llama3-instruct (llama-3 & 3.1) - system": """\n<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3 & 3.1) - user": """\n<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """\n<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "llama2-chat - system": """\n[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{prompt} [/INST] """,
        "llama2-chat - user": """\n[INST] {prompt} [/INST] """,
        "mistral & mixtral v2 tokenizer - system": """\n<s>[INST] System Prompt:[{system_prompt}]\n\n{prompt} [/INST] """,
        "mistral & mixtral v2 tokenizer - user": """\n<s>[INST] {prompt} [/INST] """
    }

    if bake_in_prompt_syntax:
        template = model_family_syntax[prompt_template]
        if system_prompt:
            return template.format(system_prompt=system_prompt, prompt=prompt)
        # Without a system prompt, still apply the template (use a "- user"
        # variant; the original fell through and returned the raw prompt here).
        return template.format(prompt=prompt)
    return prompt
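
# For example, with prompt_template = "llama2-chat - system" and
# system_prompt = "Be brief", a prompt "hi" is baked into:
#   "\n[INST] <<SYS>>\nBe brief\n<</SYS>>\n\nhi [/INST] "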


def generate_response(watsonx_llm, prompt_data, params):
    # generate_text_stream yields text chunks as they arrive, so the caller
    # (st.write_stream) can render the reply incrementally.
    generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
    for chunk in generated_response:
        yield chunk


def chat_interface():
    st.subheader("Jimmy")

    user_input = st.chat_input("You:", key="user_input")

    if user_input:
        # Build the prompt from prior turns first, then record the new user
        # turn; appending first would duplicate the new message inside the
        # serialized "Conversation History".
        prompt = prepare_prompt(user_input, st.session_state.chat_history)
        st.session_state.chat_history.append({"role": "user", "content": user_input})

        prompt_data = apply_prompt_syntax(
            prompt,
            genparam.SYSTEM_PROMPT,
            genparam.PROMPT_TEMPLATE,
            genparam.BAKE_IN_PROMPT_SYNTAX
        )

        client = setup_client()
        watsonx_llm = ModelInference(
            api_client=client,
            model_id=genparam.SELECTED_MODEL,
            verify=genparam.VERIFY
        )

        params = {
            GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
            GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
            GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
            GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
            GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
        }

        # Stream the model's reply into the chat UI as it is generated.
        with st.chat_message("Jimmy", avatar="😒"):
            stream = generate_response(watsonx_llm, prompt_data, params)
            response = st.write_stream(stream)

        st.session_state.chat_history.append({"role": "Jimmy", "content": response})


def main():
    initialize_session_state()
    chat_interface()


if __name__ == "__main__":
    main()
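
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py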