MilanM commited on
Commit
bcee819
·
verified ·
1 Parent(s): 6a0f397

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from io import BytesIO
3
+ import ibm_watsonx_ai
4
+ import secretsload
5
+ import genparam
6
+ import requests
7
+ import time
8
+ import re
9
+ from ibm_watsonx_ai.foundation_models import ModelInference
10
+ from ibm_watsonx_ai import Credentials, APIClient
11
+ from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
12
+ from ibm_watsonx_ai.metanames import GenTextReturnOptMetaNames as RetParams
13
+ from secretsload import load_stsecrets
14
+
15
+ credentials = load_stsecrets()
16
+ print(credentials)
17
+
18
+ st.set_page_config(
19
+ page_title="Jimmy the Jailbreak",
20
+ page_icon="🏴‍☠️",
21
+ initial_sidebar_state="collapsed"
22
+ )
23
+
24
+ # Password protection
25
+ def check_password():
26
+ def password_entered():
27
+ if st.session_state["password"] == st.secrets["app_password"]:
28
+ st.session_state["password_correct"] = True
29
+ del st.session_state["password"]
30
+ else:
31
+ st.session_state["password_correct"] = False
32
+
33
+ if "password_correct" not in st.session_state:
34
+ st.markdown("\n\n")
35
+ st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
36
+ st.divider()
37
+ st.info("Developed by Milan Mrdenovic © IBM Norway 2024")
38
+ return False
39
+ elif not st.session_state["password_correct"]:
40
+ st.markdown("\n\n")
41
+ st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
42
+ st.divider()
43
+ st.info("Developed by Milan Mrdenovic © IBM Norway 2024")
44
+ st.error("😕 Password incorrect")
45
+ return False
46
+ else:
47
+ return True
48
+
49
+ if not check_password():
50
+ st.stop()
51
+
52
+
53
+ # Initialize session state
54
+ if 'current_page' not in st.session_state:
55
+ st.session_state.current_page = 0
56
+
57
+ def initialize_session_state():
58
+ if 'chat_history' not in st.session_state:
59
+ st.session_state.chat_history = []
60
+
61
+ def setup_client():
62
+ credentials = Credentials(
63
+ url=st.secrets["url"],
64
+ api_key=st.secrets["api_key"]
65
+ )
66
+ apo = st.secrets["api_key"]
67
+ print(apo)
68
+ return APIClient(credentials, project_id=st.secrets["project_id"])
69
+
70
+ def prepare_prompt(prompt, chat_history):
71
+ if genparam.TYPE == "chat" and chat_history:
72
+ chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
73
+ return f"Conversation History:\n{chats}\n\nNew Message: {prompt}"
74
+ return prompt
75
+
76
+ def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
77
+ model_family_syntax = {
78
+ "llama3-instruct (llama-3 & 3.1) - system": """\n<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
79
+ "llama3-instruct (llama-3 & 3.1) - user": """\n<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
80
+ "granite-13b-chat & instruct - system": """\n<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
81
+ "granite-13b-chat & instruct - user": """\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
82
+ "llama2-chat - system": """\n[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{prompt} [/INST] """,
83
+ "llama2-chat - user": """\n[INST] {prompt} [/INST] """,
84
+ "mistral & mixtral v2 tokenizer - system": """\n<s>[INST] System Prompt:[{system_prompt}]\n\n{prompt} [/INST] """,
85
+ "mistral & mixtral v2 tokenizer - user": """\n<s>[INST] {prompt} [/INST] """
86
+ }
87
+
88
+ if bake_in_prompt_syntax:
89
+ template = model_family_syntax[prompt_template]
90
+ if system_prompt:
91
+ return template.format(system_prompt=system_prompt, prompt=prompt)
92
+ return prompt
93
+
94
+ def generate_response(watsonx_llm, prompt_data, params):
95
+ generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
96
+ for chunk in generated_response:
97
+ yield chunk
98
+
99
+ def chat_interface():
100
+ st.title("Jimmy")
101
+
102
+ # User input
103
+ user_input = st.chat_input("You:", key="user_input")
104
+
105
+ if user_input:
106
+ # Add user message to chat history
107
+ st.session_state.chat_history.append({"role": "user", "content": user_input})
108
+
109
+ # Prepare the prompt
110
+ prompt = prepare_prompt(user_input, st.session_state.chat_history)
111
+
112
+ # Apply prompt syntax
113
+ prompt_data = apply_prompt_syntax(
114
+ prompt,
115
+ genparam.SYSTEM_PROMPT,
116
+ genparam.PROMPT_TEMPLATE,
117
+ genparam.BAKE_IN_PROMPT_SYNTAX
118
+ )
119
+
120
+ # Setup client and model
121
+ client = setup_client()
122
+ watsonx_llm = ModelInference(
123
+ api_client=client,
124
+ model_id=genparam.SELECTED_MODEL,
125
+ verify=genparam.VERIFY
126
+ )
127
+
128
+ # Prepare parameters
129
+ params = {
130
+ GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
131
+ GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
132
+ GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
133
+ GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
134
+ GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
135
+ }
136
+
137
+ # Generate and stream response
138
+ with st.chat_message("assistant"):
139
+ stream = generate_response(watsonx_llm, prompt_data, params)
140
+ response = st.write_stream(stream)
141
+
142
+ # Add AI response to chat history
143
+ st.session_state.chat_history.append({"role": "assistant", "content": response})
144
+
145
+ def main():
146
+ initialize_session_state()
147
+ chat_interface()
148
+
149
+ if __name__ == "__main__":
150
+ main()