seanpedrickcase committed on
Commit ee7464e · 1 Parent(s): 0c818aa

Initial compatibility tested for use with Gemini and AWS Bedrock APIs

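For context, these are roughly the two API call shapes this commit adds compatibility for. This is an illustrative sketch only, not code from the commit; the model IDs, region and prompt text are placeholders taken from the choices and defaults visible in the diff below.

import json
import boto3
import google.generativeai as ai

# Gemini: configure the SDK, then generate a response from a named model
ai.configure(api_key="YOUR_GEMINI_API_KEY")  # placeholder key
gemini_model = ai.GenerativeModel("gemini-1.5-flash-002")
print(gemini_model.generate_content("Summarise the source text ...").text)

# AWS Bedrock (Anthropic Claude): send a JSON message body to the runtime client
bedrock_runtime = boto3.client("bedrock-runtime", region_name="eu-west-2")
body = json.dumps({
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 512,
    "messages": [{"role": "user", "content": [{"type": "text", "text": "Summarise the source text ..."}]}],
})
response = bedrock_runtime.invoke_model(
    body=body,
    modelId="anthropic.claude-3-haiku-20240307-v1:0",
    accept="application/json",
    contentType="application/json",
)
print(json.loads(response["body"].read())["content"][0]["text"])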
.dockerignore ADDED
@@ -0,0 +1,12 @@
1
+ *.pyc
2
+ *.ipynb
3
+ *.pdf
4
+ *.spec
5
+ *.toc
6
+ *.csv
7
+ *.bin
8
+ bootstrapper.py
9
+ build/*
10
+ dist/*
11
+ test/*
12
+ config/*
.gitignore CHANGED
@@ -8,4 +8,5 @@
8
  bootstrapper.py
9
  build/*
10
  dist/*
11
- test/*
 
 
8
  bootstrapper.py
9
  build/*
10
  dist/*
11
+ test/*
12
+ config/*
app.py CHANGED
@@ -1,44 +1,35 @@
1
- # Load in packages
2
-
3
  import os
4
- import socket
5
-
6
  from typing import Type
7
- from langchain_huggingface.embeddings import HuggingFaceEmbeddings#, HuggingFaceInstructEmbeddings
8
  from langchain_community.vectorstores import FAISS
9
  import gradio as gr
10
  import pandas as pd
11
 
12
- from transformers import AutoTokenizer
13
- import torch
14
-
15
- from llama_cpp import Llama
16
- from huggingface_hub import hf_hub_download
17
  from chatfuncs.ingest import embed_faiss_save_to_zip
18
- from chatfuncs.helper_functions import get_or_create_env_var
19
 
20
- from chatfuncs.helper_functions import ensure_output_folder_exists, get_connection_params, output_folder, get_or_create_env_var, reveal_feedback_buttons, wipe_logs
21
  from chatfuncs.aws_functions import upload_file_to_s3
22
- #from chatfuncs.llm_api_call import llm_query
23
  from chatfuncs.auth import authenticate_user
24
 
25
  PandasDataFrame = Type[pd.DataFrame]
26
 
27
  from datetime import datetime
28
  today_rev = datetime.now().strftime("%Y%m%d")
29
 
30
- ensure_output_folder_exists()
31
-
32
- host_name = socket.gethostname()
33
-
34
- access_logs_data_folder = 'logs/' + today_rev + '/' + host_name + '/'
35
- feedback_data_folder = 'feedback/' + today_rev + '/' + host_name + '/'
36
- usage_data_folder = 'usage/' + today_rev + '/' + host_name + '/'
37
 
38
  # Disable cuda devices if necessary
39
  #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
40
 
41
- #from chatfuncs.chatfuncs import *
42
  import chatfuncs.ingest as ing
43
 
44
  ###
@@ -73,21 +64,44 @@ def get_faiss_store(faiss_vstore_folder,embeddings):
73
  return vectorstore
74
 
75
  import chatfuncs.chatfuncs as chatf
 
76
 
77
  chatf.embeddings = load_embeddings(embeddings_name)
78
  chatf.vectorstore = get_faiss_store(faiss_vstore_folder="faiss_embedding",embeddings=globals()["embeddings"])
79
 
 
80
 
81
- def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_device=None):
82
- print("Loading model")
83
 
84
- # Default values inside the function
85
- if gpu_config is None:
86
- gpu_config = chatf.gpu_config
87
- if cpu_config is None:
88
- cpu_config = chatf.cpu_config
89
- if torch_device is None:
90
- torch_device = chatf.torch_device
91
 
92
  if model_type == "Phi 3.5 Mini (larger, slow)":
93
  if torch_device == "cuda":
@@ -112,8 +126,7 @@ def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_d
112
  )
113
 
114
  except Exception as e:
115
- print("GPU load failed")
116
- print(e)
117
  model = Llama(
118
  model_path=hf_hub_download(
119
  repo_id=os.environ.get("REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF"), #"QuantFactory/Phi-3-mini-128k-instruct-GGUF"), #, "microsoft/Phi-3-mini-4k-instruct-gguf"),#"QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
@@ -128,57 +141,21 @@ def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_d
128
  # Huggingface chat model
129
  hf_checkpoint = 'Qwen/Qwen2-0.5B-Instruct'# 'declare-lab/flan-alpaca-large'#'declare-lab/flan-alpaca-base' # # # 'Qwen/Qwen1.5-0.5B-Chat' #
130
 
131
- def create_hf_model(model_name):
132
-
133
- from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
134
-
135
- if torch_device == "cuda":
136
- if "flan" in model_name:
137
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
138
- else:
139
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
140
- else:
141
- if "flan" in model_name:
142
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
143
- else:
144
- model = AutoModelForCausalLM.from_pretrained(model_name)#, trust_remote_code=True)#, torch_dtype=torch.float16)
145
-
146
- tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = chatf.context_length)
147
 
148
- return model, tokenizer, model_type
149
-
150
- model, tokenizer, model_type = create_hf_model(model_name = hf_checkpoint)
151
 
152
- chatf.model = model
153
  chatf.tokenizer = tokenizer
154
  chatf.model_type = model_type
155
 
156
  load_confirmation = "Finished loading model: " + model_type
157
 
158
  print(load_confirmation)
159
- return model_type, load_confirmation, model_type
160
 
161
- # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
162
- #model_type = "Phi 3.5 Mini (larger, slow)"
163
- #load_model(model_type, chatf.gpu_layers, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
164
-
165
- model_type = "Qwen 2 0.5B (small, fast)"
166
- load_model(model_type, 0, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
167
-
168
- def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
169
-
170
- print(f"> Total split documents: {len(docs_out)}")
171
-
172
- print(docs_out)
173
-
174
- vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
175
-
176
- chatf.vectorstore = vectorstore_func
177
-
178
- out_message = "Document processing complete"
179
-
180
- return out_message, vectorstore_func
181
- # Gradio chat
182
 
183
 
184
  ###
@@ -188,17 +165,29 @@ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
188
  app = gr.Blocks(theme = gr.themes.Base(), fill_width=True)#css=".gradio-container {background-color: black}")
189
 
190
  with app:
191
  ingest_text = gr.State()
192
  ingest_metadata = gr.State()
193
  ingest_docs = gr.State()
194
 
195
  model_type_state = gr.State(model_type)
 
 
 
196
  embeddings_state = gr.State(chatf.embeddings)#globals()["embeddings"])
197
  vectorstore_state = gr.State(chatf.vectorstore)#globals()["vectorstore"])
198
 
199
  relevant_query_state = gr.Checkbox(value=True, visible=False)
200
 
201
- model_state = gr.State() # chatf.model (gives error)
202
  tokenizer_state = gr.State() # chatf.tokenizer (gives error)
203
 
204
  chat_history_state = gr.State()
@@ -222,7 +211,7 @@ with app:
222
  gr.Markdown("Chat with PDF, web page or (new) csv/Excel documents. The default is a small model (Qwen 2 0.5B), that can only answer specific questions that are answered in the text. It cannot give overall impressions of, or summarise the document. The alternative (Phi 3.5 Mini (larger, slow)), can reason a little better, but is much slower (See Advanced tab).\n\nBy default the Lambeth Borough Plan '[Lambeth 2030 : Our Future, Our Lambeth](https://www.lambeth.gov.uk/better-fairer-lambeth/projects/lambeth-2030-our-future-our-lambeth)' is loaded. If you want to talk about another document or web page, please select from the second tab. If switching topic, please click the 'Clear chat' button.\n\nCaution: This is a public app. Please ensure that the document you upload is not sensitive is any way as other users may see it! Also, please note that LLM chatbots may give incomplete or incorrect information, so please use with care.")
223
 
224
  with gr.Accordion(label="Use Gemini or AWS Claude model", open=False, visible=False):
225
- api_model_choice = gr.Dropdown(value = "None", choices = ["gemini-1.5-flash-002", "gemini-1.5-pro-002", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "None"], label="LLM model to use", multiselect=False, interactive=True, visible=False)
226
  in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=False)
227
 
228
  with gr.Row():
@@ -233,7 +222,7 @@ with app:
233
 
234
  with gr.Row():
235
  #chat_height = 500
236
- chatbot = gr.Chatbot(avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='tuples') # , height=chat_height
237
  with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = True):
238
  sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here") # , height=chat_height
239
 
@@ -281,13 +270,12 @@ with app:
281
  out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
282
  temp_slide = gr.Slider(minimum=0.1, value = 0.5, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
283
  with gr.Row():
284
- model_choice = gr.Radio(label="Choose a chat model", value="Qwen 2 0.5B (small, fast)", choices = ["Qwen 2 0.5B (small, fast)", "Phi 3.5 Mini (larger, slow)"])
285
  change_model_button = gr.Button(value="Load model", scale=0)
286
  with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False):
287
  gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step = 1, visible=True)
288
 
289
- load_text = gr.Text(label="Load status")
290
-
291
 
292
  gr.HTML(
293
  "<center>This app is based on the models Qwen 2 0.5B and Phi 3.5 Mini. It powered by Gradio, Transformers, and Llama.cpp.</a></center>"
@@ -295,11 +283,41 @@ with app:
295
 
296
  examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])
297
 
298
- change_model_button.click(fn=chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
299
- success(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs = [model_type_state, load_text, current_model]).\
300
- success(lambda: chatf.restore_interactivity(), None, [message], queue=False).\
301
- success(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
302
- success(lambda: None, None, chatbot, queue=False)
303
 
304
  # Load in a pdf
305
  load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
@@ -318,38 +336,23 @@ with app:
318
  success(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
319
  success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
320
  success(chatf.hide_block, outputs = [examples_set])
 
321
 
322
- # Load in a webpage
323
-
324
- # Click/enter to send message action
325
- response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
326
- success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
327
- success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], outputs=chatbot)
328
- response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
329
- success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
330
- success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
331
-
332
- response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
333
- success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
334
- success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], chatbot)
335
- response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
336
- success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
337
- success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
338
-
339
- # Stop box
340
- stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])
341
-
342
- # Clear box
343
- clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
344
- clear.click(lambda: None, None, chatbot, queue=False)
345
 
346
- # Thumbs up or thumbs down voting function
347
- chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None)
 
 
 
348
 
349
  ###
350
  # LOGGING AND ON APP LOAD FUNCTIONS
351
  ###
352
- app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state, session_hash_textbox])
 
353
 
354
  # Log usernames and times of access to file (to know who is using the app when running on AWS)
355
  access_callback = gr.CSVLogger()
@@ -358,10 +361,6 @@ with app:
358
  session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
359
  success(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
360
 
361
- # Launch the Gradio app
362
- COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
363
- print(f'The value of COGNITO_AUTH is {COGNITO_AUTH}')
364
-
365
  if __name__ == "__main__":
366
  if os.environ['COGNITO_AUTH'] == "1":
367
  app.queue().launch(show_error=True, auth=authenticate_user, max_file_size='50mb')
 
 
 
1
  import os
 
 
2
  from typing import Type
3
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
  import gradio as gr
6
  import pandas as pd
7
 
8
  from chatfuncs.ingest import embed_faiss_save_to_zip
 
9
 
10
+ from chatfuncs.helper_functions import ensure_output_folder_exists, get_connection_params, output_folder, reveal_feedback_buttons, wipe_logs
11
  from chatfuncs.aws_functions import upload_file_to_s3
 
12
  from chatfuncs.auth import authenticate_user
13
+ from chatfuncs.config import FEEDBACK_LOGS_FOLDER, ACCESS_LOGS_FOLDER, USAGE_LOGS_FOLDER, HOST_NAME, COGNITO_AUTH
14
+
15
+ from llama_cpp import Llama
16
+ from huggingface_hub import hf_hub_download
17
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
18
+ import os
19
 
20
  PandasDataFrame = Type[pd.DataFrame]
21
 
22
  from datetime import datetime
23
  today_rev = datetime.now().strftime("%Y%m%d")
24
 
25
+ host_name = HOST_NAME
26
+ access_logs_data_folder = ACCESS_LOGS_FOLDER
27
+ feedback_data_folder = FEEDBACK_LOGS_FOLDER
28
+ usage_data_folder = USAGE_LOGS_FOLDER
 
 
 
29
 
30
  # Disable cuda devices if necessary
31
  #os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
32
 
 
33
  import chatfuncs.ingest as ing
34
 
35
  ###
 
64
  return vectorstore
65
 
66
  import chatfuncs.chatfuncs as chatf
67
+ from chatfuncs.model_load import torch_device, gpu_config, cpu_config, context_length
68
 
69
  chatf.embeddings = load_embeddings(embeddings_name)
70
  chatf.vectorstore = get_faiss_store(faiss_vstore_folder="faiss_embedding",embeddings=globals()["embeddings"])
71
 
72
+ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
73
 
74
+ print(f"> Total split documents: {len(docs_out)}")
75
+
76
+ print(docs_out)
77
+
78
+ vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
79
+
80
+ chatf.vectorstore = vectorstore_func
81
+
82
+ out_message = "Document processing complete"
83
+
84
+ return out_message, vectorstore_func
85
+ # Gradio chat
86
+
87
+ def create_hf_model(model_name:str):
88
+ if torch_device == "cuda":
89
+ if "flan" in model_name:
90
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
91
+ else:
92
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
93
+ else:
94
+ if "flan" in model_name:
95
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
96
+ else:
97
+ model = AutoModelForCausalLM.from_pretrained(model_name)#, trust_remote_code=True)#, torch_dtype=torch.float16)
98
+
99
+ tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = context_length)
100
 
101
+ return model, tokenizer
102
+
103
+ def load_model(model_type:str, gpu_layers:int, gpu_config:dict=gpu_config, cpu_config:dict=cpu_config, torch_device:str=torch_device):
104
+ print("Loading model")
 
 
 
105
 
106
  if model_type == "Phi 3.5 Mini (larger, slow)":
107
  if torch_device == "cuda":
 
126
  )
127
 
128
  except Exception as e:
129
+ print("GPU load failed", e)
 
130
  model = Llama(
131
  model_path=hf_hub_download(
132
  repo_id=os.environ.get("REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF"), #"QuantFactory/Phi-3-mini-128k-instruct-GGUF"), #, "microsoft/Phi-3-mini-4k-instruct-gguf"),#"QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
 
141
  # Huggingface chat model
142
  hf_checkpoint = 'Qwen/Qwen2-0.5B-Instruct'# 'declare-lab/flan-alpaca-large'#'declare-lab/flan-alpaca-base' # # # 'Qwen/Qwen1.5-0.5B-Chat' #
143
 
144
+ model, tokenizer = create_hf_model(model_name = hf_checkpoint)
145
 
146
+ else:
147
+ model = model_type
148
+ tokenizer = ""
149
 
150
+ chatf.model_object = model
151
  chatf.tokenizer = tokenizer
152
  chatf.model_type = model_type
153
 
154
  load_confirmation = "Finished loading model: " + model_type
155
 
156
  print(load_confirmation)
 
157
 
158
+ return model_type, load_confirmation, model_type#model, tokenizer, model_type
159
 
160
 
161
  ###
 
165
  app = gr.Blocks(theme = gr.themes.Base(), fill_width=True)#css=".gradio-container {background-color: black}")
166
 
167
  with app:
168
+ model_type = "Qwen 2 0.5B (small, fast)"
169
+ load_model(model_type, 0, gpu_config, cpu_config, torch_device) # chatf.model_object, chatf.tokenizer, chatf.model_type =
170
+
171
+ print("chatf.model_object:", chatf.model_object)
172
+
173
+ # Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
174
+ #model_type = "Phi 3.5 Mini (larger, slow)"
175
+ #load_model(model_type, gpu_layers, gpu_config, cpu_config, torch_device)
176
+
177
  ingest_text = gr.State()
178
  ingest_metadata = gr.State()
179
  ingest_docs = gr.State()
180
 
181
  model_type_state = gr.State(model_type)
182
+ gpu_config_state = gr.State(gpu_config)
183
+ cpu_config_state = gr.State(cpu_config)
184
+ torch_device_state = gr.State(torch_device)
185
  embeddings_state = gr.State(chatf.embeddings)#globals()["embeddings"])
186
  vectorstore_state = gr.State(chatf.vectorstore)#globals()["vectorstore"])
187
 
188
  relevant_query_state = gr.Checkbox(value=True, visible=False)
189
 
190
+ model_state = gr.State() # chatf.model_object (gives error)
191
  tokenizer_state = gr.State() # chatf.tokenizer (gives error)
192
 
193
  chat_history_state = gr.State()
 
211
  gr.Markdown("Chat with PDF, web page or (new) csv/Excel documents. The default is a small model (Qwen 2 0.5B), that can only answer specific questions that are answered in the text. It cannot give overall impressions of, or summarise the document. The alternative (Phi 3.5 Mini (larger, slow)), can reason a little better, but is much slower (See Advanced tab).\n\nBy default the Lambeth Borough Plan '[Lambeth 2030 : Our Future, Our Lambeth](https://www.lambeth.gov.uk/better-fairer-lambeth/projects/lambeth-2030-our-future-our-lambeth)' is loaded. If you want to talk about another document or web page, please select from the second tab. If switching topic, please click the 'Clear chat' button.\n\nCaution: This is a public app. Please ensure that the document you upload is not sensitive is any way as other users may see it! Also, please note that LLM chatbots may give incomplete or incorrect information, so please use with care.")
212
 
213
  with gr.Accordion(label="Use Gemini or AWS Claude model", open=False, visible=False):
214
+ api_model_choice = gr.Dropdown(value = "None", choices = ["gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "None"], label="LLM model to use", multiselect=False, interactive=True, visible=False)
215
  in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=False)
216
 
217
  with gr.Row():
 
222
 
223
  with gr.Row():
224
  #chat_height = 500
225
+ chatbot = gr.Chatbot(value=None, avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='messages') # , height=chat_height
226
  with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = True):
227
  sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here") # , height=chat_height
228
 
 
270
  out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
271
  temp_slide = gr.Slider(minimum=0.1, value = 0.5, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
272
  with gr.Row():
273
+ model_choice = gr.Radio(label="Choose a chat model", value="Qwen 2 0.5B (small, fast)", choices = ["Qwen 2 0.5B (small, fast)", "Phi 3.5 Mini (larger, slow)", "gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"])
274
  change_model_button = gr.Button(value="Load model", scale=0)
275
  with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False):
276
  gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step = 1, visible=True)
277
 
278
+ load_text = gr.Text(label="Load status")
 
279
 
280
  gr.HTML(
281
  "<center>This app is based on the models Qwen 2 0.5B and Phi 3.5 Mini. It powered by Gradio, Transformers, and Llama.cpp.</a></center>"
 
283
 
284
  examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])
285
 
286
+
287
+ ###
288
+ # CHAT PAGE
289
+ ###
290
+
291
+ # Click to send message
292
+ response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
293
+ success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
294
+ success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state], outputs=chatbot)
295
+ response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
296
+ success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
297
+ success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)
298
+
299
+ # Press enter to send message
300
+ response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
301
+ success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
302
+ success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state], chatbot)
303
+ response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
304
+ success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
305
+ success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)
306
+
307
+ # Stop box
308
+ stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])
309
+
310
+ # Clear box
311
+ clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
312
+ clear.click(lambda: None, None, chatbot, queue=False)
313
+
314
+ # Thumbs up or thumbs down voting function
315
+ chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None)
316
+
317
+
318
+ ###
319
+ # LOAD NEW DATA PAGE
320
+ ###
321
 
322
  # Load in a pdf
323
  load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
 
336
  success(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
337
  success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
338
  success(chatf.hide_block, outputs = [examples_set])
339
+
340
 
341
+ ###
342
+ # LOAD MODEL PAGE
343
+ ###
344
 
345
+ change_model_button.click(fn=chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
346
+ success(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs = [model_type_state, load_text, current_model]).\
347
+ success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False).\
348
+ success(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
349
+ success(lambda: None, None, chatbot, queue=False)
350
 
351
  ###
352
  # LOGGING AND ON APP LOAD FUNCTIONS
353
  ###
354
+ app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state, session_hash_textbox]).\
355
+ success(load_model, inputs=[model_type_state, gpu_layer_choice, gpu_config_state, cpu_config_state, torch_device_state], outputs=[model_type_state, load_text, current_model])
356
 
357
  # Log usernames and times of access to file (to know who is using the app when running on AWS)
358
  access_callback = gr.CSVLogger()
 
361
  session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
362
  success(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
363
 
 
 
 
 
364
  if __name__ == "__main__":
365
  if os.environ['COGNITO_AUTH'] == "1":
366
  app.queue().launch(show_error=True, auth=authenticate_user, max_file_size='50mb')
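The app.py wiring above relies on Gradio event chaining: .click() returns an event object, and each .success() step runs only if the previous step finished without raising, which is how the message box and submit button are locked while load_model runs and unlocked afterwards. A minimal sketch of that pattern, with hypothetical stand-ins for turn_off_interactivity, load_model and restore_interactivity:

import gradio as gr

def disable_inputs():
    # Stand-in for chatf.turn_off_interactivity: lock the textbox and button
    return gr.update(interactive=False), gr.update(interactive=False)

def restore_inputs():
    # Stand-in for chatf.restore_interactivity: unlock them again
    return gr.update(interactive=True), gr.update(interactive=True)

def slow_load(choice):
    # Stand-in for load_model: pretend to load and return a status string
    return f"Finished loading model: {choice}"

with gr.Blocks() as demo:
    model_choice = gr.Radio(["Qwen 2 0.5B (small, fast)"], value="Qwen 2 0.5B (small, fast)", label="Choose a chat model")
    message = gr.Textbox(label="Message")
    submit = gr.Button("Submit")
    load_button = gr.Button("Load model")
    load_text = gr.Text(label="Load status")

    # Each .success() step only fires if the previous one raised no error
    load_button.click(disable_inputs, None, [message, submit], queue=False).\
        success(slow_load, inputs=[model_choice], outputs=[load_text]).\
        success(restore_inputs, None, [message, submit], queue=False)

demo.launch()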
chatfuncs/auth.py CHANGED
@@ -1,14 +1,22 @@
1
-
2
  import boto3
3
- from chatfuncs.helper_functions import get_or_create_env_var
4
-
5
- client_id = get_or_create_env_var('AWS_CLIENT_ID', '') # This client id is borrowed from async gradio app client
6
- print(f'The value of AWS_CLIENT_ID is {client_id}')
 
7
 
8
- user_pool_id = get_or_create_env_var('AWS_USER_POOL_ID', '')
9
- print(f'The value of AWS_USER_POOL_ID is {user_pool_id}')
10
 
11
- def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=client_id):
12
  """Authenticates a user against an AWS Cognito user pool.
13
 
14
  Args:
@@ -16,22 +24,39 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
16
  client_id (str): The ID of the Cognito user pool client.
17
  username (str): The username of the user.
18
  password (str): The password of the user.
 
19
 
20
  Returns:
21
  bool: True if the user is authenticated, False otherwise.
22
  """
23
 
24
- client = boto3.client('cognito-idp') # Cognito Identity Provider client
 
 
 
25
 
26
  try:
27
- response = client.initiate_auth(
28
  AuthFlow='USER_PASSWORD_AUTH',
29
  AuthParameters={
30
  'USERNAME': username,
31
  'PASSWORD': password,
 
32
  },
33
  ClientId=client_id
34
- )
35
 
36
  # If successful, you'll receive an AuthenticationResult in the response
37
  if response.get('AuthenticationResult'):
@@ -44,5 +69,7 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
44
  except client.exceptions.UserNotFoundException:
45
  return False
46
  except Exception as e:
47
- print(f"An error occurred: {e}")
48
- return False
 
 
 
1
+ #import os
2
  import boto3
3
+ #import gradio as gr
4
+ import hmac
5
+ import hashlib
6
+ import base64
7
+ from chatfuncs.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_USER_POOL_ID, AWS_REGION
8
 
9
+ def calculate_secret_hash(client_id:str, client_secret:str, username:str):
10
+ message = username + client_id
11
+ dig = hmac.new(
12
+ str(client_secret).encode('utf-8'),
13
+ msg=str(message).encode('utf-8'),
14
+ digestmod=hashlib.sha256
15
+ ).digest()
16
+ secret_hash = base64.b64encode(dig).decode()
17
+ return secret_hash
18
 
19
+ def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL_ID, client_id:str=AWS_CLIENT_ID, client_secret:str=AWS_CLIENT_SECRET):
20
  """Authenticates a user against an AWS Cognito user pool.
21
 
22
  Args:
 
24
  client_id (str): The ID of the Cognito user pool client.
25
  username (str): The username of the user.
26
  password (str): The password of the user.
27
+ client_secret (str): The client secret of the app client
28
 
29
  Returns:
30
  bool: True if the user is authenticated, False otherwise.
31
  """
32
 
33
+ client = boto3.client('cognito-idp', region_name=AWS_REGION) # Cognito Identity Provider client
34
+
35
+ # Compute the secret hash
36
+ secret_hash = calculate_secret_hash(client_id, client_secret, username)
37
 
38
  try:
39
+
40
+ if client_secret == '':
41
+ response = client.initiate_auth(
42
+ AuthFlow='USER_PASSWORD_AUTH',
43
+ AuthParameters={
44
+ 'USERNAME': username,
45
+ 'PASSWORD': password,
46
+ },
47
+ ClientId=client_id
48
+ )
49
+
50
+ else:
51
+ response = client.initiate_auth(
52
  AuthFlow='USER_PASSWORD_AUTH',
53
  AuthParameters={
54
  'USERNAME': username,
55
  'PASSWORD': password,
56
+ 'SECRET_HASH': secret_hash
57
  },
58
  ClientId=client_id
59
+ )
60
 
61
  # If successful, you'll receive an AuthenticationResult in the response
62
  if response.get('AuthenticationResult'):
 
69
  except client.exceptions.UserNotFoundException:
70
  return False
71
  except Exception as e:
72
+ out_message = f"An error occurred: {e}"
73
+ print(out_message)
74
+ raise Exception(out_message)
75
+ return False
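The substantive change here is the SECRET_HASH that Cognito requires whenever the app client has a client secret: Base64(HMAC-SHA256(key=client_secret, message=username + client_id)). A standalone sketch of that calculation, mirroring calculate_secret_hash above and using hypothetical values rather than real credentials:

import base64
import hashlib
import hmac

def secret_hash(client_id: str, client_secret: str, username: str) -> str:
    # HMAC-SHA256 over username + client_id, keyed by the client secret, then base64-encoded
    digest = hmac.new(
        client_secret.encode("utf-8"),
        msg=(username + client_id).encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    return base64.b64encode(digest).decode()

# Hypothetical values for illustration only
print(secret_hash("example-client-id", "example-client-secret", "jane.doe"))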
chatfuncs/chatfuncs.py CHANGED
@@ -6,11 +6,16 @@ import time
6
  from itertools import compress
7
  import pandas as pd
8
  import numpy as np
9
 
10
  # Model packages
11
  import torch.cuda
12
  from threading import Thread
13
  from transformers import pipeline, TextIteratorStreamer
 
14
 
15
  # Alternative model sources
16
  #from dataclasses import asdict, dataclass
@@ -22,31 +27,37 @@ from langchain_community.retrievers import SVMRetriever
22
  from langchain.text_splitter import RecursiveCharacterTextSplitter
23
  from langchain.docstore.document import Document
24
 
25
  # For keyword extraction (not currently used)
26
  #import nltk
27
  #nltk.download('wordnet')
28
  from nltk.corpus import stopwords
29
  from nltk.tokenize import RegexpTokenizer
30
  from nltk.stem import WordNetLemmatizer
31
- #from nltk.stem.snowball import SnowballStemmer
32
  from keybert import KeyBERT
33
 
34
  # For Name Entity Recognition model
35
  #from span_marker import SpanMarkerModel # Not currently used
36
 
37
-
38
  # For BM25 retrieval
39
  import bm25s
40
  import Stemmer
41
 
42
- #from gensim.corpora import Dictionary
43
- #from gensim.models import TfidfModel, OkapiBM25Model
44
- #from gensim.similarities import SparseMatrixSimilarity
45
-
46
- from llama_cpp import Llama
47
- from huggingface_hub import hf_hub_download
48
-
49
- from chatfuncs.prompts import instruction_prompt_template_alpaca, instruction_prompt_mistral_orca, instruction_prompt_phi3, instruction_prompt_llama3, instruction_prompt_qwen
50
 
51
  import gradio as gr
52
 
@@ -60,10 +71,8 @@ model_type = None # global variable setup
60
 
61
  max_memory_length = 0 # How long should the memory of the conversation last?
62
 
63
- full_text = "" # Define dummy source text (full text) just to enable highlight function to load
64
 
65
- model = [] # Define empty list for model functions to run
66
- tokenizer = [] # Define empty list for model functions to run
67
 
68
  ## Highlight text constants
69
  hlt_chunk_size = 12
@@ -77,84 +86,6 @@ ner_model = []#SpanMarkerModel.from_pretrained("tomaarsen/span-marker-mbert-base
77
  # Used to pull out keywords from chat history to add to user queries behind the scenes
78
  kw_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
79
 
80
- # Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
81
- if torch.cuda.is_available():
82
- torch_device = "cuda"
83
- gpu_layers = 100
84
- else:
85
- torch_device = "cpu"
86
- gpu_layers = 0
87
-
88
- print("Running on device:", torch_device)
89
- threads = 8 #torch.get_num_threads()
90
- print("CPU threads:", threads)
91
-
92
- # Qwen 2 0.5B (small, fast) Model parameters
93
- temperature: float = 0.1
94
- top_k: int = 3
95
- top_p: float = 1
96
- repetition_penalty: float = 1.15
97
- flan_alpaca_repetition_penalty: float = 1.3
98
- last_n_tokens: int = 64
99
- max_new_tokens: int = 1024
100
- seed: int = 42
101
- reset: bool = False
102
- stream: bool = True
103
- threads: int = threads
104
- batch_size:int = 256
105
- context_length:int = 2048
106
- sample = True
107
-
108
-
109
- class CtransInitConfig_gpu:
110
- def __init__(self,
111
- last_n_tokens=last_n_tokens,
112
- seed=seed,
113
- n_threads=threads,
114
- n_batch=batch_size,
115
- n_ctx=4096,
116
- n_gpu_layers=gpu_layers):
117
-
118
- self.last_n_tokens = last_n_tokens
119
- self.seed = seed
120
- self.n_threads = n_threads
121
- self.n_batch = n_batch
122
- self.n_ctx = n_ctx
123
- self.n_gpu_layers = n_gpu_layers
124
- # self.stop: list[str] = field(default_factory=lambda: [stop_string])
125
-
126
- def update_gpu(self, new_value):
127
- self.n_gpu_layers = new_value
128
-
129
- class CtransInitConfig_cpu(CtransInitConfig_gpu):
130
- def __init__(self):
131
- super().__init__()
132
- self.n_gpu_layers = 0
133
-
134
- gpu_config = CtransInitConfig_gpu()
135
- cpu_config = CtransInitConfig_cpu()
136
-
137
-
138
- class CtransGenGenerationConfig:
139
- def __init__(self, temperature=temperature,
140
- top_k=top_k,
141
- top_p=top_p,
142
- repeat_penalty=repetition_penalty,
143
- seed=seed,
144
- stream=stream,
145
- max_tokens=max_new_tokens
146
- ):
147
- self.temperature = temperature
148
- self.top_k = top_k
149
- self.top_p = top_p
150
- self.repeat_penalty = repeat_penalty
151
- self.seed = seed
152
- self.max_tokens=max_tokens
153
- self.stream = stream
154
-
155
- def update_temp(self, new_value):
156
- self.temperature = new_value
157
-
158
  # Vectorstore funcs
159
 
160
  def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
@@ -187,7 +118,7 @@ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
187
 
188
  # Prompt functions
189
 
190
- def base_prompt_templates(model_type = "Qwen 2 0.5B (small, fast)"):
191
 
192
  #EXAMPLE_PROMPT = PromptTemplate(
193
  # template="\nCONTENT:\n\n{page_content}\n\nSOURCE: {source}\n\n",
@@ -205,20 +136,20 @@ def base_prompt_templates(model_type = "Qwen 2 0.5B (small, fast)"):
205
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_qwen, input_variables=['question', 'summaries'])
206
  elif model_type == "Phi 3.5 Mini (larger, slow)":
207
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_phi3, input_variables=['question', 'summaries'])
 
 
 
208
 
209
  return INSTRUCTION_PROMPT, CONTENT_PROMPT
210
 
211
- def write_out_metadata_as_string(metadata_in):
212
  metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
213
  return metadata_string
214
 
215
- def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag = True, out_passages = 2): # ,
216
 
217
  question = inputs["question"]
218
  chat_history = inputs["chat_history"]
219
-
220
- print("relevant_flag in generate_expanded_prompt:", relevant_flag)
221
-
222
 
223
  if relevant_flag == True:
224
  new_question_kworded = adapt_q_from_chat_history(question, chat_history, extracted_memory) # new_question_keywords,
@@ -234,8 +165,6 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
234
  return sorry_prompt, "No relevant sources found.", new_question_kworded
235
 
236
  # Expand the found passages to the neighbouring context
237
- print("Doc_df columns:", doc_df.columns)
238
-
239
  if 'meta_url' in doc_df.columns:
240
  file_type = determine_file_type(doc_df['meta_url'][0])
241
  else:
@@ -265,14 +194,22 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
265
 
266
  return instruction_prompt_out, sources_docs_content_string, new_question_kworded
267
 
268
- def create_full_prompt(user_input, history, extracted_memory, vectorstore, embeddings, model_type, out_passages, api_model_choice=None, api_key=None, relevant_flag = True):
269
 
270
  #if chain_agent is None:
271
  # history.append((user_input, "Please click the button to submit the Huggingface API key before using the chatbot (top right)"))
272
  # return history, history, "", ""
273
  print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
274
-
275
-
276
  history = history or []
277
 
278
  if api_model_choice and api_model_choice != "None":
@@ -298,31 +235,13 @@ def create_full_prompt(user_input, history, extracted_memory, vectorstore, embed
298
  generate_expanded_prompt({"question": user_input, "chat_history": history}, #vectorstore,
299
  instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag, out_passages)
300
 
301
- history.append(user_input)
302
 
303
  print("Output history is:", history)
304
  print("Final prompt to model is:",instruction_prompt_out)
305
 
306
  return history, docs_content_string, instruction_prompt_out, relevant_flag
307
 
308
- # Chat functions
309
- import boto3
310
- import json
311
- from chatfuncs.helper_functions import get_or_create_env_var
312
-
313
- # ResponseObject class for AWS Bedrock calls
314
- class ResponseObject:
315
- def __init__(self, text, usage_metadata):
316
- self.text = text
317
- self.usage_metadata = usage_metadata
318
-
319
- max_tokens = 4096
320
-
321
- AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', 'eu-west-2')
322
- print(f'The value of AWS_DEFAULT_REGION is {AWS_DEFAULT_REGION}')
323
-
324
- bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
325
-
326
  def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
327
  """
328
  This function sends a request to AWS Claude with the following parameters:
@@ -351,6 +270,8 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
351
  ],
352
  }
353
 
 
 
354
  body = json.dumps(prompt_config)
355
 
356
  modelId = model_choice
@@ -376,16 +297,173 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
376
 
377
  return response
378
 
379
- def produce_streaming_answer_chatbot(history,
380
- full_prompt,
381
- model_type,
382
- temperature=temperature,
383
- relevant_query_bool=True,
384
- max_new_tokens=max_new_tokens,
385
- sample=sample,
386
- repetition_penalty=repetition_penalty,
387
- top_p=top_p,
388
- top_k=top_k
389
  ):
390
  #print("Model type is: ", model_type)
391
 
@@ -395,16 +473,19 @@ def produce_streaming_answer_chatbot(history,
395
 
396
  # return history
397
 
398
-
 
 
399
 
400
  if relevant_query_bool == False:
401
- out_message = [("","No relevant query found. Please retry your question")]
402
- history.append(out_message)
403
 
404
  yield history
405
  return
406
 
407
  if model_type == "Qwen 2 0.5B (small, fast)":
 
 
408
  # Get the model and tokenizer, and tokenize the user text.
409
  model_inputs = tokenizer(text=full_prompt, return_tensors="pt", return_attention_mask=False).to(torch_device)
410
 
@@ -422,9 +503,9 @@ def produce_streaming_answer_chatbot(history,
422
  top_k=top_k
423
  )
424
 
425
- #print(generate_kwargs)
426
 
427
- t = Thread(target=model.generate, kwargs=generate_kwargs)
428
  t.start()
429
 
430
  # Pull the generated text from the streamer, and update the model output.
@@ -432,12 +513,14 @@ def produce_streaming_answer_chatbot(history,
432
  NUM_TOKENS=0
433
  print('-'*4+'Start Generation'+'-'*4)
434
 
435
- history[-1][1] = ""
 
436
  for new_text in streamer:
437
  try:
438
- if new_text == None: new_text = ""
439
- history[-1][1] += new_text
440
- NUM_TOKENS+=1
 
441
  yield history
442
  except Exception as e:
443
  print(f"Error during text generation: {e}")
@@ -463,14 +546,15 @@ def produce_streaming_answer_chatbot(history,
463
  NUM_TOKENS=0
464
  print('-'*4+'Start Generation'+'-'*4)
465
 
466
- output = model(
467
  full_prompt, **vars(gen_config))
468
 
469
- history[-1][1] = ""
 
470
  for out in output:
471
 
472
  if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
473
- history[-1][1] += out["choices"][0]["text"]
474
  NUM_TOKENS+=1
475
  yield history
476
  else:
@@ -481,36 +565,75 @@ def produce_streaming_answer_chatbot(history,
481
  print('-'*4+'End Generation'+'-'*4)
482
  print(f'Num of generated tokens: {NUM_TOKENS}')
483
  print(f'Time for complete generation: {time_generate}s')
484
- print(f'Tokens per secound: {NUM_TOKENS/time_generate}')
485
  print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
486
 
487
- elif model_type == "anthropic.claude-3-haiku-20240307-v1:0" or model_type == "anthropic.claude-3-sonnet-20240229-v1:0":
488
  system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
489
 
490
- try:
491
- print("Calling AWS Claude model")
492
- response = call_aws_claude(full_prompt, system_prompt, temperature, max_tokens, model_type)
493
- except Exception as e:
494
- # If fails, try again after 10 seconds in case there is a throttle limit
495
- print(e)
496
- try:
497
- out_message = "API limit hit - waiting 30 seconds to retry."
498
- print(out_message)
499
 
500
- time.sleep(30)
501
- response = call_aws_claude(full_prompt, system_prompt, temperature, max_tokens, model_type)
502
-
503
- except Exception as e:
504
- print(e)
505
- return "", history
506
  # Update the conversation history with the new prompt and response
507
- history.append({'role': 'user', 'parts': [full_prompt]})
508
- history.append({'role': 'assistant', 'parts': [response.text]})
509
 
510
- # Print the updated conversation history
511
- #print("conversation_history:", conversation_history)
512
 
513
- return response, history
514
 
515
  # Chat helper functions
516
 
@@ -589,9 +712,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
589
 
590
  docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
591
 
592
- print("Docs from similarity search:")
593
- print(docs)
594
-
595
  # Keep only documents with a certain score
596
  docs_len = [len(x[0].page_content) for x in docs]
597
  docs_scores = [x[1] for x in docs]
@@ -688,12 +808,8 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
688
  # 3rd level check on retrieved docs with SVM retriever
689
  # Check the type of the embeddings object
690
  embeddings_type = type(embeddings)
691
- print("Type of embeddings object:", embeddings_type)
692
-
693
 
694
- print("embeddings:", embeddings)
695
 
696
- from langchain_huggingface import HuggingFaceEmbeddings
697
  #hf_embeddings = HuggingFaceEmbeddings(**embeddings)
698
  hf_embeddings = embeddings
699
 
@@ -743,10 +859,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
743
  # Make df of best options
744
  doc_df = create_doc_df(docs_keep_out)
745
 
746
- print("doc_df:",doc_df)
747
- print("docs_keep_as_doc:",docs_keep_as_doc)
748
- print("docs_keep_out:", docs_keep_out)
749
-
750
  return docs_keep_as_doc, doc_df, docs_keep_out
751
 
752
  def get_expanded_passages(vectorstore, docs, width):
@@ -836,16 +948,16 @@ def get_expanded_passages(vectorstore, docs, width):
836
 
837
  return expanded_docs, doc_df
838
 
839
- def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
840
  """
841
- Highlights occurrences of search_text within full_text.
842
 
843
  Parameters:
844
- - search_text (str): The text to be searched for within full_text.
845
- - full_text (str): The text within which search_text occurrences will be highlighted.
846
 
847
  Returns:
848
- - str: A string with occurrences of search_text highlighted.
849
 
850
  Example:
851
  >>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
@@ -859,32 +971,31 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
859
  return text[i][0].replace(" ", " ").strip()
860
  else:
861
  return ""
862
 
863
- def extract_search_text_from_input(text):
864
- if isinstance(text, str):
865
- return text.replace(" ", " ").strip()
866
- elif isinstance(text, list):
867
- return text[-1][1].replace(" ", " ").strip()
868
- else:
869
- return ""
870
-
871
- full_text = extract_text_from_input(full_text)
872
- search_text = extract_search_text_from_input(search_text)
873
-
874
 
 
875
 
876
  text_splitter = RecursiveCharacterTextSplitter(
877
  chunk_size=hlt_chunk_size,
878
  separators=hlt_strat,
879
  chunk_overlap=hlt_overlap,
880
  )
881
- sections = text_splitter.split_text(search_text)
882
 
883
  found_positions = {}
884
  for x in sections:
885
  text_start_pos = 0
886
  while text_start_pos != -1:
887
- text_start_pos = full_text.find(x, text_start_pos)
888
  if text_start_pos != -1:
889
  found_positions[text_start_pos] = text_start_pos + len(x)
890
  text_start_pos += 1
@@ -907,20 +1018,24 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
907
  prev_end = 0
908
  for start, end in combined_positions:
909
  if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
910
- pos_tokens.append(full_text[prev_end:start])
911
- pos_tokens.append('<mark style="color:black;">' + full_text[start:end] + '</mark>')
912
  prev_end = end
913
- pos_tokens.append(full_text[prev_end:])
 
 
914
 
915
- return "".join(pos_tokens)
 
 
916
 
917
 
918
  # # Chat history functions
919
 
920
  def clear_chat(chat_history_state, sources, chat_message, current_topic):
921
- chat_history_state = []
922
  sources = ''
923
- chat_message = ''
924
  current_topic = ''
925
 
926
  return chat_history_state, sources, chat_message, current_topic
@@ -1011,8 +1126,7 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
1011
  for word in tokens_without_sw:
1012
  if word not in ordered_tokens:
1013
  ordered_tokens.add(word)
1014
- result.append(word)
1015
-
1016
 
1017
 
1018
  new_question_keywords = ' '.join(result)
@@ -1021,9 +1135,6 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
1021
  def remove_q_ner_extractor(question):
1022
 
1023
  predict_out = ner_model.predict(question)
1024
-
1025
-
1026
-
1027
  predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]
1028
 
1029
  # Remove duplicate words while preserving order
@@ -1075,11 +1186,11 @@ def keybert_keywords(text, n, kw_model):
1075
  return keywords_list
1076
 
1077
  # Gradio functions
1078
- def turn_off_interactivity(user_message, history):
1079
- return gr.update(value="", interactive=False), history + [[user_message, None]]
1080
 
1081
  def restore_interactivity():
1082
- return gr.update(interactive=True)
1083
 
1084
  def update_message(dropdown_value):
1085
  return gr.Textbox(value=dropdown_value)
 
6
  from itertools import compress
7
  import pandas as pd
8
  import numpy as np
9
+ import google.generativeai as ai
10
+ from gradio import Progress
11
+ import boto3
12
+ import json
13
 
14
  # Model packages
15
  import torch.cuda
16
  from threading import Thread
17
  from transformers import pipeline, TextIteratorStreamer
18
+ from langchain_huggingface import HuggingFaceEmbeddings
19
 
20
  # Alternative model sources
21
  #from dataclasses import asdict, dataclass
 
27
  from langchain.text_splitter import RecursiveCharacterTextSplitter
28
  from langchain.docstore.document import Document
29
 
30
+ from chatfuncs.config import GEMINI_API_KEY, AWS_DEFAULT_REGION
31
+
32
+ model_object = [] # Define empty list for model functions to run
33
+ tokenizer = [] # Define empty list for model functions to run
34
+
35
+ from chatfuncs.model_load import temperature, max_new_tokens, sample, repetition_penalty, top_p, top_k, torch_device, CtransGenGenerationConfig, max_tokens
36
+
37
+ # ResponseObject class for AWS Bedrock calls
38
+ class ResponseObject:
39
+ def __init__(self, text, usage_metadata):
40
+ self.text = text
41
+ self.usage_metadata = usage_metadata
42
+
43
+ bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
44
+
45
  # For keyword extraction (not currently used)
46
  #import nltk
47
  #nltk.download('wordnet')
48
  from nltk.corpus import stopwords
49
  from nltk.tokenize import RegexpTokenizer
50
  from nltk.stem import WordNetLemmatizer
 
51
  from keybert import KeyBERT
52
 
53
  # For Name Entity Recognition model
54
  #from span_marker import SpanMarkerModel # Not currently used
55
 
 
56
  # For BM25 retrieval
57
  import bm25s
58
  import Stemmer
59
 
60
+ from chatfuncs.prompts import instruction_prompt_template_alpaca, instruction_prompt_mistral_orca, instruction_prompt_phi3, instruction_prompt_llama3, instruction_prompt_qwen, instruction_prompt_template_orca
61
 
62
  import gradio as gr
63
 
 
71
 
72
  max_memory_length = 0 # How long should the memory of the conversation last?
73
 
74
+ source_texts = "" # Define dummy source text (full text) just to enable highlight function to load
75
 
 
 
76
 
77
  ## Highlight text constants
78
  hlt_chunk_size = 12
 
86
  # Used to pull out keywords from chat history to add to user queries behind the scenes
87
  kw_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
88
 
89
  # Vectorstore funcs
90
 
91
  def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
 
118
 
119
  # Prompt functions
120
 
121
+ def base_prompt_templates(model_type:str = "Qwen 2 0.5B (small, fast)"):
122
 
123
  #EXAMPLE_PROMPT = PromptTemplate(
124
  # template="\nCONTENT:\n\n{page_content}\n\nSOURCE: {source}\n\n",
 
136
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_qwen, input_variables=['question', 'summaries'])
137
  elif model_type == "Phi 3.5 Mini (larger, slow)":
138
  INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_phi3, input_variables=['question', 'summaries'])
139
+ else:
140
+ INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_orca, input_variables=['question', 'summaries'])
141
+
142
 
143
  return INSTRUCTION_PROMPT, CONTENT_PROMPT
144
 
145
+ def write_out_metadata_as_string(metadata_in:str):
146
  metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
147
  return metadata_string
148
 
149
+ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt:str, content_prompt:str, extracted_memory:list, vectorstore:object, embeddings:object, relevant_flag:bool = True, out_passages:int = 2): # ,
150
 
151
  question = inputs["question"]
152
  chat_history = inputs["chat_history"]
 
 
 
153
 
154
  if relevant_flag == True:
155
  new_question_kworded = adapt_q_from_chat_history(question, chat_history, extracted_memory) # new_question_keywords,
 
165
  return sorry_prompt, "No relevant sources found.", new_question_kworded
166
 
167
  # Expand the found passages to the neighbouring context
 
 
168
  if 'meta_url' in doc_df.columns:
169
  file_type = determine_file_type(doc_df['meta_url'][0])
170
  else:
 
194
 
195
  return instruction_prompt_out, sources_docs_content_string, new_question_kworded
196
 
197
+ def create_full_prompt(user_input:str,
198
+ history:list[dict],
199
+ extracted_memory:str,
200
+ vectorstore:object,
201
+ embeddings:object,
202
+ model_type:str,
203
+ out_passages:list[str],
204
+ api_model_choice=None,
205
+ api_key=None,
206
+ relevant_flag = True):
207
 
208
  #if chain_agent is None:
209
  # history.append((user_input, "Please click the button to submit the Huggingface API key before using the chatbot (top right)"))
210
  # return history, history, "", ""
211
  print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
212
+
 
213
  history = history or []
214
 
215
  if api_model_choice and api_model_choice != "None":
 
235
  generate_expanded_prompt({"question": user_input, "chat_history": history}, #vectorstore,
236
  instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag, out_passages)
237
 
238
+ history.append({"metadata":None, "options":None, "role": 'user', "content": user_input})
239
 
240
  print("Output history is:", history)
241
  print("Final prompt to model is:",instruction_prompt_out)
242
 
243
  return history, docs_content_string, instruction_prompt_out, relevant_flag
244
 
 
245
  def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
246
  """
247
  This function sends a request to AWS Claude with the following parameters:
 
270
  ],
271
  }
272
 
273
+ print("prompt_config:", prompt_config)
274
+
275
  body = json.dumps(prompt_config)
276
 
277
  modelId = model_choice
 
297
 
298
  return response
299
 
300
+ def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int) -> Tuple[object, dict]:
301
+ """
302
+ Constructs a GenerativeModel for Gemini API calls.
303
+
304
+ Parameters:
305
+ - in_api_key (str): The API key for authentication.
306
+ - temperature (float): The temperature parameter for the model, controlling the randomness of the output.
307
+ - model_choice (str): The choice of model to use for generation.
308
+ - system_prompt (str): The system prompt to guide the generation.
309
+ - max_tokens (int): The maximum number of tokens to generate.
310
+
311
+ Returns:
312
+ - Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
313
+ """
314
+ # Construct a GenerativeModel
315
+ try:
316
+ if in_api_key:
317
+ #print("Getting API key from textbox")
318
+ api_key = in_api_key
319
+ ai.configure(api_key=api_key)
320
+ elif "GOOGLE_API_KEY" in os.environ:
321
+ #print("Searching for API key in environmental variables")
322
+ api_key = os.environ["GOOGLE_API_KEY"]
323
+ ai.configure(api_key=api_key)
324
+ else:
325
+ print("No API key found")
326
+ raise gr.Error("No API key found.")
327
+ except Exception as e:
328
+ print(e)
329
+
330
+ config = ai.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens)
331
+
332
+ print("model_choice:", model_choice)
333
+
334
+ #model = ai.GenerativeModel.from_cached_content(cached_content=cache, generation_config=config)
335
+ model = ai.GenerativeModel(model_name=model_choice, system_instruction=system_prompt, generation_config=config)
336
+
337
+ return model, config
338
+
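A minimal usage sketch for the helper above, assuming the ai alias points at google.generativeai and a key is available via the textbox or the GOOGLE_API_KEY environment variable; the model name and prompts are illustrative.

# Illustrative usage - model name and prompts are placeholders
model, config = construct_gemini_generative_model(
    in_api_key=GEMINI_API_KEY,
    temperature=0.1,
    model_choice="gemini-2.0-flash",
    system_prompt="You answer questions using only the supplied sources.",
    max_tokens=4096,
)
response = model.generate_content(contents="QUESTION - What is the budget?", generation_config=config)
print(response.text)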
339
+ # Function to send a request and update history
340
+ def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
341
+ """
342
+ This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
343
+ It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
344
+ If the model choice is specific to AWS Claude, it calls the `call_aws_claude` function; otherwise, it uses the `model.generate_content` method.
345
+ The function returns the model response and the updated conversation history.
346
+ """
347
+ # Constructing the full prompt from the conversation history
348
+ full_prompt = "Conversation history:\n"
349
+
350
+ for entry in conversation_history:
351
+ role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'parts'
352
+ message = ' '.join(entry['parts']) # Combining all parts of the message
353
+ full_prompt += f"{role}: {message}\n"
354
+
355
+ # Adding the new user prompt
356
+ full_prompt += f"\nUser: {prompt}"
357
+
358
+ # Print the full prompt for debugging purposes
359
+ #print("full_prompt:", full_prompt)
360
+
361
+ # Generate the model's response
362
+ if "gemini" in model_choice:
363
+ try:
364
+ response = model.generate_content(contents=full_prompt, generation_config=config)
365
+ except Exception as e:
366
+ # If fails, try again after 30 seconds in case there is a throttle limit
367
+ print(e)
368
+ try:
369
+ print("Calling Gemini model")
370
+ out_message = "API limit hit - waiting 30 seconds to retry."
371
+ print(out_message)
372
+ progress(0.5, desc=out_message)
373
+ time.sleep(30)
374
+ response = model.generate_content(contents=full_prompt, generation_config=config)
375
+ except Exception as e:
376
+ print(e)
377
+ return "", conversation_history
378
+ elif "claude" in model_choice:
379
+ try:
380
+ print("Calling AWS Claude model")
381
+ print("prompt:", prompt)
382
+ print("system_prompt:", system_prompt)
383
+ response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
384
+ except Exception as e:
385
+ # If fails, try again after 30 seconds in case there is a throttle limit
386
+ print(e)
387
+ try:
388
+ out_message = "API limit hit - waiting 30 seconds to retry."
389
+ print(out_message)
390
+ progress(0.5, desc=out_message)
391
+ time.sleep(30)
392
+ response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
393
+
394
+ except Exception as e:
395
+ print(e)
396
+ return "", conversation_history
397
+ else:
398
+ raise Exception("Model not found")
399
+
400
+ # Update the conversation history with the new prompt and response
401
+ conversation_history.append({"metadata":None, "options":None, "role": 'user', 'parts': [prompt]})
402
+ conversation_history.append({"metadata":None, "options":None, "role": "assistant", 'parts': [response.text]})
403
+
404
+ # Print the updated conversation history
405
+ #print("conversation_history:", conversation_history)
406
+
407
+ return response, conversation_history
408
+
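Note the shape of the history entries that send_request reads and appends: each entry carries 'role' and a 'parts' list, plus the 'metadata'/'options' keys expected by the Gradio chatbot. For example:

# Illustrative conversation_history entries, matching what send_request appends
conversation_history = [
    {"metadata": None, "options": None, "role": "user", "parts": ["What does the source say about funding?"]},
    {"metadata": None, "options": None, "role": "assistant", "parts": ["Funding was increased in 2023."]},
]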
409
+ def process_requests(prompts: List[str], system_prompt_with_table: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
410
+ """
411
+ Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
412
+
413
+ Args:
414
+ prompts (List[str]): A list of prompts to be processed.
415
+ system_prompt_with_table (str): The system prompt including a table.
416
+ conversation_history (List[dict]): The history of the conversation.
417
+ whole_conversation (List[str]): The complete conversation including prompts and responses.
418
+ whole_conversation_metadata (List[str]): Metadata about the whole conversation.
419
+ model (object): The model to use for processing the prompts.
420
+ config (dict): Configuration for the model.
421
+ model_choice (str): The choice of model to use.
422
+ temperature (float): The temperature parameter for the model.
423
+ batch_no (int): Batch number of the large language model request.
424
+ master (bool): Is this request for the master table.
425
+
426
+ Returns:
427
+ Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
428
+ """
429
+ responses = []
430
+ #for prompt in prompts:
431
+
432
+ response, conversation_history = send_request(prompts[0], conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt_with_table, temperature=temperature)
433
+
434
+ print(response.text)
435
+ #"Okay, I'm ready. What source are we discussing, and what's your question about it? Please provide as much context as possible so I can give you the best answer."]
436
+ print(response.usage_metadata)
437
+ responses.append(response)
438
+
439
+ # Create conversation txt object
440
+ whole_conversation.append(prompts[0])
441
+ whole_conversation.append(response.text)
442
+
443
+ # Create conversation metadata
444
+ if master == False:
445
+ whole_conversation_metadata.append(f"Query batch {batch_no} prompt {len(responses)} metadata:")
446
+ else:
447
+ whole_conversation_metadata.append(f"Query summary metadata:")
448
+
449
+ whole_conversation_metadata.append(str(response.usage_metadata))
450
+
451
+ return responses, conversation_history, whole_conversation, whole_conversation_metadata
452
+
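A sketch of a single-batch call to process_requests, mirroring how it is used further down in produce_streaming_answer_chatbot; the variable values are placeholders.

# Illustrative single-batch call - inputs are placeholders
responses, conversation_history, whole_conversation, whole_conversation_metadata = process_requests(
    prompts=[full_prompt],
    system_prompt_with_table=system_prompt,
    conversation_history=[],
    whole_conversation=[],
    whole_conversation_metadata=[],
    model=model,
    config=config,
    model_choice=model_type,
    temperature=temperature,
    batch_no=1,
    master=False,
)
latest_response_text = responses[-1].text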
453
+ def produce_streaming_answer_chatbot(
454
+ history:list,
455
+ full_prompt:str,
456
+ model_type:str,
457
+ temperature:float=temperature,
458
+ relevant_query_bool:bool=True,
459
+ chat_history:list[dict]=[{"metadata":None, "options":None, "role": 'user', "content": ""}],
460
+ max_new_tokens:int=max_new_tokens,
461
+ sample:bool=sample,
462
+ repetition_penalty:float=repetition_penalty,
463
+ top_p:float=top_p,
464
+ top_k:int=top_k,
465
+ max_tokens:int=max_tokens,
466
+ in_api_key:str=GEMINI_API_KEY
467
  ):
468
  #print("Model type is: ", model_type)
469
 
 
473
 
474
  # return history
475
 
476
+ history = chat_history
477
+
478
+ print("history at start of streaming function:", history)
479
 
480
  if relevant_query_bool == False:
481
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": 'No relevant query found. Please retry your question'})
 
482
 
483
  yield history
484
  return
485
 
486
  if model_type == "Qwen 2 0.5B (small, fast)":
487
+
488
+ print("tokenizer:", tokenizer)
489
  # Get the model and tokenizer, and tokenize the user text.
490
  model_inputs = tokenizer(text=full_prompt, return_tensors="pt", return_attention_mask=False).to(torch_device)
491
 
 
503
  top_k=top_k
504
  )
505
 
506
+ print("model_object:", model_object)
507
 
508
+ t = Thread(target=model_object.generate, kwargs=generate_kwargs)
509
  t.start()
510
 
511
  # Pull the generated text from the streamer, and update the model output.
 
513
  NUM_TOKENS=0
514
  print('-'*4+'Start Generation'+'-'*4)
515
 
516
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
517
+
518
  for new_text in streamer:
519
  try:
520
+ if new_text is None:
521
+ new_text = ""
522
+ history[-1]['content'] += new_text
523
+ NUM_TOKENS += 1
524
  yield history
525
  except Exception as e:
526
  print(f"Error during text generation: {e}")
 
546
  NUM_TOKENS=0
547
  print('-'*4+'Start Generation'+'-'*4)
548
 
549
+ output = model_object(
550
  full_prompt, **vars(gen_config))
551
 
552
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
553
+
554
  for out in output:
555
 
556
  if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
557
+ history[-1]['content'] += out["choices"][0]["text"]
558
  NUM_TOKENS+=1
559
  yield history
560
  else:
 
565
  print('-'*4+'End Generation'+'-'*4)
566
  print(f'Num of generated tokens: {NUM_TOKENS}')
567
  print(f'Time for complete generation: {time_generate}s')
568
+ print(f'Tokens per second: {NUM_TOKENS/time_generate}')
569
  print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
570
 
571
+ elif "claude" in model_type:
572
  system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
573
 
574
+ print("full_prompt:", full_prompt)
575
+
576
+ if isinstance(full_prompt, str):
577
+ full_prompt = [full_prompt]
578
+
579
+ model = model_type
580
+ config = {}
581
+
582
+ responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
583
+
584
+ if isinstance(responses[-1], ResponseObject):
585
+ response_texts = [resp.text for resp in responses]
586
+ elif "choices" in responses[-1]:
587
+ response_texts = [resp["choices"][0]['text'] for resp in responses]
588
+ else:
589
+ response_texts = [resp.text for resp in responses]
590
+
591
+ latest_response_text = response_texts[-1]
592
 
 
593
  # Update the conversation history with the new prompt and response
594
+ clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text) # Replace newlines, tabs, and carriage returns with a space
595
+ clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip() # Remove all characters that are not printable ASCII
596
 
597
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})
 
598
 
599
+ for char in clean_response_text:
600
+ time.sleep(0.005)
601
+ history[-1]['content'] += char
602
+ yield history
603
+
604
+ elif "gemini" in model_type:
605
+ print("Using Gemini model:", model_type)
606
+ print("full_prompt:", full_prompt)
607
+
608
+ if isinstance(full_prompt, str):
609
+ full_prompt = [full_prompt]
610
+
611
+ system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
612
+
613
+ model, config = construct_gemini_generative_model(GEMINI_API_KEY, temperature, model_type, system_prompt, max_tokens)
614
+
615
+ responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
616
+
617
+ if isinstance(responses[-1], ResponseObject):
618
+ response_texts = [resp.text for resp in responses]
619
+ elif "choices" in responses[-1]:
620
+ response_texts = [resp["choices"][0]['text'] for resp in responses]
621
+ else:
622
+ response_texts = [resp.text for resp in responses]
623
+
624
+ latest_response_text = response_texts[-1]
625
+
626
+ clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text) # Replace newlines, tabs, and carriage returns with a space
627
+ clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip() # Remove all characters that are not printable ASCII
628
+
629
+ history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})
630
+
631
+ for char in clean_response_text:
632
+ time.sleep(0.005)
633
+ history[-1]['content'] += char
634
+ yield history
635
+
636
+ print("history at end of function:", history)
637
 
638
  # Chat helper functions
639
 
 
712
 
713
  docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
714
 
 
 
 
715
  # Keep only documents with a certain score
716
  docs_len = [len(x[0].page_content) for x in docs]
717
  docs_scores = [x[1] for x in docs]
 
808
  # 3rd level check on retrieved docs with SVM retriever
809
  # Check the type of the embeddings object
810
  embeddings_type = type(embeddings)
 
 
811
 
 
812
 
 
813
  #hf_embeddings = HuggingFaceEmbeddings(**embeddings)
814
  hf_embeddings = embeddings
815
 
 
859
  # Make df of best options
860
  doc_df = create_doc_df(docs_keep_out)
861
 
 
862
  return docs_keep_as_doc, doc_df, docs_keep_out
863
 
864
  def get_expanded_passages(vectorstore, docs, width):
 
948
 
949
  return expanded_docs, doc_df
950
 
951
+ def highlight_found_text(chat_history: list[dict], source_texts: list[dict], hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
952
  """
953
+ Highlights occurrences of the latest assistant response from chat_history within source_texts.
954
 
955
  Parameters:
956
+ - chat_history (list[dict]): The chat history; the most recent assistant response is the text searched for within source_texts.
957
+ - source_texts (list[dict]): The source documents within which occurrences of the response text will be highlighted.
958
 
959
  Returns:
960
+ - str: A string with occurrences of the response text highlighted.
961
 
962
  Example:
963
  >>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
 
971
  return text[i][0].replace(" ", " ").strip()
972
  else:
973
  return ""
974
+
975
+ print("chat_history:", chat_history)
976
+
977
+ response_text = next(
978
+ (entry['content'] for entry in reversed(chat_history) if entry.get('role') == 'assistant'),
979
+ "")
980
 
981
+ print("response_text:", response_text)
982
+
983
+ source_texts = extract_text_from_input(source_texts)
 
984
 
985
+ print("source_texts:", source_texts)
986
 
987
  text_splitter = RecursiveCharacterTextSplitter(
988
  chunk_size=hlt_chunk_size,
989
  separators=hlt_strat,
990
  chunk_overlap=hlt_overlap,
991
  )
992
+ sections = text_splitter.split_text(response_text)
993
 
994
  found_positions = {}
995
  for x in sections:
996
  text_start_pos = 0
997
  while text_start_pos != -1:
998
+ text_start_pos = source_texts.find(x, text_start_pos)
999
  if text_start_pos != -1:
1000
  found_positions[text_start_pos] = text_start_pos + len(x)
1001
  text_start_pos += 1
 
1018
  prev_end = 0
1019
  for start, end in combined_positions:
1020
  if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
1021
+ pos_tokens.append(source_texts[prev_end:start])
1022
+ pos_tokens.append('<mark style="color:black;">' + source_texts[start:end] + '</mark>')
1023
  prev_end = end
1024
+ pos_tokens.append(source_texts[prev_end:])
1025
+
1026
+ out_pos_tokens = "".join(pos_tokens)
1027
 
1028
+ print("out_pos_tokens:", out_pos_tokens)
1029
+
1030
+ return out_pos_tokens
1031
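An illustrative call with the updated signature, where the chat history is a list of role/content dicts and source_texts stands for the retrieved documents (for example docs_keep_out from the retrieval step above):

# Illustrative usage of highlight_found_text with message-dict history
chat_history = [
    {"metadata": None, "options": None, "role": "user", "content": "What is the budget?"},
    {"metadata": None, "options": None, "role": "assistant", "content": "The budget is 2 million."},
]
highlighted_html = highlight_found_text(chat_history, source_texts=docs_keep_out)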
 
1032
 
1033
  # # Chat history functions
1034
 
1035
  def clear_chat(chat_history_state, sources, chat_message, current_topic):
1036
+ chat_history_state = None
1037
  sources = ''
1038
+ chat_message = None
1039
  current_topic = ''
1040
 
1041
  return chat_history_state, sources, chat_message, current_topic
 
1126
  for word in tokens_without_sw:
1127
  if word not in ordered_tokens:
1128
  ordered_tokens.add(word)
1129
+ result.append(word)
 
1130
 
1131
 
1132
  new_question_keywords = ' '.join(result)
 
1135
  def remove_q_ner_extractor(question):
1136
 
1137
  predict_out = ner_model.predict(question)
 
 
 
1138
  predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]
1139
 
1140
  # Remove duplicate words while preserving order
 
1186
  return keywords_list
1187
 
1188
  # Gradio functions
1189
+ def turn_off_interactivity():
1190
+ return gr.Textbox(interactive=False), gr.Button(interactive=False)
1191
 
1192
  def restore_interactivity():
1193
+ return gr.Textbox(interactive=True), gr.Button(interactive=True)
1194
 
1195
  def update_message(dropdown_value):
1196
  return gr.Textbox(value=dropdown_value)
chatfuncs/config.py ADDED
@@ -0,0 +1,217 @@
1
+ import os
2
+ import tempfile
3
+ import socket
4
+ import logging
5
+ from datetime import datetime
6
+ from dotenv import load_dotenv
7
+
8
+ today_rev = datetime.now().strftime("%Y%m%d")
9
+ HOST_NAME = socket.gethostname()
10
+
11
+ # Set or retrieve configuration variables for the app
12
+
13
+ def get_or_create_env_var(var_name:str, default_value:str, print_val:bool=False):
14
+ '''
15
+ Get an environmental variable, and set it to a default value if it doesn't exist
16
+ '''
17
+ # Get the environment variable if it exists
18
+ value = os.environ.get(var_name)
19
+
20
+ # If it doesn't exist, set the environment variable to the default value
21
+ if value is None:
22
+ os.environ[var_name] = default_value
23
+ value = default_value
24
+
25
+ if print_val == True:
26
+ print(f'The value of {var_name} is {value}')
27
+
28
+ return value
29
+
30
+ def ensure_folder_exists(output_folder:str):
31
+ """Checks if the specified folder exists, creates it if not."""
32
+
33
+ if not os.path.exists(output_folder):
34
+ # Create the folder if it doesn't exist
35
+ os.makedirs(output_folder, exist_ok=True)
36
+ print(f"Created the {output_folder} folder.")
37
+ else:
38
+ print(f"The {output_folder} folder already exists.")
39
+
40
+ def add_folder_to_path(folder_path: str):
41
+ '''
42
+ Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
43
+ '''
44
+
45
+ if os.path.exists(folder_path) and os.path.isdir(folder_path):
46
+ print(folder_path, "folder exists.")
47
+
48
+ # Resolve relative path to absolute path
49
+ absolute_path = os.path.abspath(folder_path)
50
+
51
+ current_path = os.environ['PATH']
52
+ if absolute_path not in current_path.split(os.pathsep):
53
+ full_path_extension = absolute_path + os.pathsep + current_path
54
+ os.environ['PATH'] = full_path_extension
55
+ #print(f"Updated PATH with: ", full_path_extension)
56
+ else:
57
+ print(f"Directory {folder_path} already exists in PATH.")
58
+ else:
59
+ print(f"Folder not found at {folder_path} - not added to PATH")
60
+
61
+ ensure_folder_exists("config/")
62
+
63
+ # If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
64
+ APP_CONFIG_PATH = get_or_create_env_var('APP_CONFIG_PATH', 'config/app_config.env') # e.g. config/app_config.env
65
+
66
+ if APP_CONFIG_PATH:
67
+ if os.path.exists(APP_CONFIG_PATH):
68
+ print(f"Loading app variables from config file {APP_CONFIG_PATH}")
69
+ load_dotenv(APP_CONFIG_PATH)
70
+ else: print("App config file not found at location:", APP_CONFIG_PATH)
71
+
72
+ # Report logging to console?
73
+ LOGGING = get_or_create_env_var('LOGGING', 'False')
74
+
75
+ if LOGGING == 'True':
76
+ # Configure logging
77
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
78
+
79
+ ###
80
+ # AWS CONFIG
81
+ ###
82
+
83
+ # If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'env/aws_config.env'
84
+ AWS_CONFIG_PATH = get_or_create_env_var('AWS_CONFIG_PATH', '') # e.g. config/aws_config.env
85
+
86
+ if AWS_CONFIG_PATH:
87
+ if os.path.exists(AWS_CONFIG_PATH):
88
+ print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
89
+ load_dotenv(AWS_CONFIG_PATH)
90
+ else: print("AWS config file not found at location:", AWS_CONFIG_PATH)
91
+
92
+ RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
93
+
94
+ AWS_REGION = get_or_create_env_var('AWS_REGION', '')
95
+
96
+ AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', '')
97
+
98
+ AWS_CLIENT_ID = get_or_create_env_var('AWS_CLIENT_ID', '')
99
+
100
+ AWS_CLIENT_SECRET = get_or_create_env_var('AWS_CLIENT_SECRET', '')
101
+
102
+ AWS_USER_POOL_ID = get_or_create_env_var('AWS_USER_POOL_ID', '')
103
+
104
+ AWS_ACCESS_KEY = get_or_create_env_var('AWS_ACCESS_KEY', '')
105
+ if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
106
+
107
+ AWS_SECRET_KEY = get_or_create_env_var('AWS_SECRET_KEY', '')
108
+ if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
109
+
110
+ DOCUMENT_REDACTION_BUCKET = get_or_create_env_var('DOCUMENT_REDACTION_BUCKET', '')
111
+
112
+ # Custom headers e.g. if routing traffic through Cloudfront
113
+ # Retrieving or setting CUSTOM_HEADER
114
+ CUSTOM_HEADER = get_or_create_env_var('CUSTOM_HEADER', '')
115
+ #if CUSTOM_HEADER: print(f'CUSTOM_HEADER found')
116
+
117
+ # Retrieving or setting CUSTOM_HEADER_VALUE
118
+ CUSTOM_HEADER_VALUE = get_or_create_env_var('CUSTOM_HEADER_VALUE', '')
119
+ #if CUSTOM_HEADER_VALUE: print(f'CUSTOM_HEADER_VALUE found')
120
+
121
+ ###
122
+ # File I/O config
123
+ ###
124
+ SESSION_OUTPUT_FOLDER = get_or_create_env_var('SESSION_OUTPUT_FOLDER', 'False') # i.e. do you want your input and output folders saved within a subfolder based on session hash value within output/input folders
125
+
126
+ OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/') # 'output/'
127
+ INPUT_FOLDER = get_or_create_env_var('GRADIO_INPUT_FOLDER', 'input/') # 'input/'
128
+
129
+ ensure_folder_exists(OUTPUT_FOLDER)
130
+ ensure_folder_exists(INPUT_FOLDER)
131
+
132
+ # Allow for files to be saved in a temporary folder for increased security in some instances
133
+ if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
134
+ # Create a temporary directory
135
+ with tempfile.TemporaryDirectory() as temp_dir:
136
+ print(f'Temporary directory created at: {temp_dir}')
137
+
138
+ if OUTPUT_FOLDER == "TEMP": OUTPUT_FOLDER = temp_dir + "/"
139
+ if INPUT_FOLDER == "TEMP": INPUT_FOLDER = temp_dir + "/"
140
+
141
+ # By default, logs are put into a subfolder of today's date and the host name of the instance running the app. This is to avoid at all possible the possibility of log files from one instance overwriting the logs of another instance on S3. If running the app on one system always, or just locally, it is not necessary to make the log folders so specific.
142
+ # Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
143
+
144
+ USE_LOG_SUBFOLDERS = get_or_create_env_var('USE_LOG_SUBFOLDERS', 'True')
145
+
146
+ if USE_LOG_SUBFOLDERS == "True":
147
+ day_log_subfolder = today_rev + '/'
148
+ host_name_subfolder = HOST_NAME + '/'
149
+ full_log_subfolder = day_log_subfolder + host_name_subfolder
150
+ else:
151
+ full_log_subfolder = ""
152
+
153
+ FEEDBACK_LOGS_FOLDER = get_or_create_env_var('FEEDBACK_LOGS_FOLDER', 'feedback/' + full_log_subfolder)
154
+ ACCESS_LOGS_FOLDER = get_or_create_env_var('ACCESS_LOGS_FOLDER', 'logs/' + full_log_subfolder)
155
+ USAGE_LOGS_FOLDER = get_or_create_env_var('USAGE_LOGS_FOLDER', 'usage/' + full_log_subfolder)
156
+
157
+ ensure_folder_exists(FEEDBACK_LOGS_FOLDER)
158
+ ensure_folder_exists(ACCESS_LOGS_FOLDER)
159
+ ensure_folder_exists(USAGE_LOGS_FOLDER)
160
+
161
+ # Should the redacted file name be included in the logs? In some instances, the names of the files themselves could be sensitive, and should not be disclosed beyond the app. So, by default this is false.
162
+ DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var('DISPLAY_FILE_NAMES_IN_LOGS', 'False')
163
+
164
+ ###
165
+ # RUN CONFIG
166
+ GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
167
+
168
+
169
+ # Number of pages to loop through before breaking the function and restarting from the last finished page (not currently activated).
170
+ PAGE_BREAK_VALUE = get_or_create_env_var('PAGE_BREAK_VALUE', '99999')
171
+
172
+ MAX_TIME_VALUE = get_or_create_env_var('MAX_TIME_VALUE', '999999')
173
+
174
+ ###
175
+ # APP RUN CONFIG
176
+ ###
177
+
178
+ # Get some environment variables and Launch the Gradio app
179
+ COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
180
+
181
+ RUN_DIRECT_MODE = get_or_create_env_var('RUN_DIRECT_MODE', '0')
182
+
183
+ MAX_QUEUE_SIZE = int(get_or_create_env_var('MAX_QUEUE_SIZE', '5'))
184
+
185
+ MAX_FILE_SIZE = get_or_create_env_var('MAX_FILE_SIZE', '250mb')
186
+
187
+ GRADIO_SERVER_PORT = int(get_or_create_env_var('GRADIO_SERVER_PORT', '7860'))
188
+
189
+ ROOT_PATH = get_or_create_env_var('ROOT_PATH', '')
190
+
191
+ DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var('DEFAULT_CONCURRENCY_LIMIT', '3')
192
+
193
+ GET_DEFAULT_ALLOW_LIST = get_or_create_env_var('GET_DEFAULT_ALLOW_LIST', 'False')
194
+
195
+ ALLOW_LIST_PATH = get_or_create_env_var('ALLOW_LIST_PATH', '') # config/default_allow_list.csv
196
+
197
+ S3_ALLOW_LIST_PATH = get_or_create_env_var('S3_ALLOW_LIST_PATH', '') # default_allow_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
198
+
199
+ if ALLOW_LIST_PATH: OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
200
+ else: OUTPUT_ALLOW_LIST_PATH = 'config/default_allow_list.csv'
201
+
202
+ SHOW_COSTS = get_or_create_env_var('SHOW_COSTS', 'False')
203
+
204
+ GET_COST_CODES = get_or_create_env_var('GET_COST_CODES', 'False')
205
+
206
+ DEFAULT_COST_CODE = get_or_create_env_var('DEFAULT_COST_CODE', '')
207
+
208
+ COST_CODES_PATH = get_or_create_env_var('COST_CODES_PATH', '') # 'config/COST_CENTRES.csv' # file should be a csv file with a single table in it that has two columns with a header. First column should contain cost codes, second column should contain a name or description for the cost code
209
+
210
+ S3_COST_CODES_PATH = get_or_create_env_var('S3_COST_CODES_PATH', '') # COST_CENTRES.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
211
+
212
+ if COST_CODES_PATH: OUTPUT_COST_CODES_PATH = COST_CODES_PATH
213
+ else: OUTPUT_COST_CODES_PATH = 'config/COST_CENTRES.csv'
214
+
215
+ ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?
216
+
217
+ if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
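For reference, settings in this module can be overridden through the dotenv file loaded above; the sketch below shows a hypothetical config/app_config.env and how the resulting values are read back (variable names are from this module, values are placeholders).

# Illustrative only - config/app_config.env might contain, for example:
#   LOGGING=True
#   GEMINI_API_KEY=your-gemini-api-key
#   GRADIO_SERVER_PORT=7860
#   USE_LOG_SUBFOLDERS=True
from chatfuncs.config import GEMINI_API_KEY, GRADIO_SERVER_PORT, LOGGING

print("Logging enabled:", LOGGING)
print("Server port:", GRADIO_SERVER_PORT)
print("Gemini key set:", bool(GEMINI_API_KEY))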
chatfuncs/model_load.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+
3
+ # Offload layers to the GPU when CUDA is available; set gpu_layers to 0 below to force CPU-only inference if the CUDA implementation proves unstable
4
+ if torch.cuda.is_available():
5
+ torch_device = "cuda"
6
+ gpu_layers = 100
7
+ else:
8
+ torch_device = "cpu"
9
+ gpu_layers = 0
10
+
11
+ print("Running on device:", torch_device)
12
+ threads = 8 #torch.get_num_threads()
13
+ print("CPU threads:", threads)
14
+
15
+ # Qwen 2 0.5B (small, fast) Model parameters
16
+ temperature: float = 0.1
17
+ top_k: int = 3
18
+ top_p: float = 1
19
+ repetition_penalty: float = 1.15
20
+ flan_alpaca_repetition_penalty: float = 1.3
21
+ last_n_tokens: int = 64
22
+ max_new_tokens: int = 1024
23
+ seed: int = 42
24
+ reset: bool = False
25
+ stream: bool = True
26
+ threads: int = threads
27
+ batch_size:int = 256
28
+ context_length:int = 2048
29
+ sample = True
30
+
31
+ # Bedrock parameters
32
+ max_tokens = 4096
33
+
34
+
35
+ class CtransInitConfig_gpu:
36
+ def __init__(self,
37
+ last_n_tokens=last_n_tokens,
38
+ seed=seed,
39
+ n_threads=threads,
40
+ n_batch=batch_size,
41
+ n_ctx=max_tokens,
42
+ n_gpu_layers=gpu_layers):
43
+
44
+ self.last_n_tokens = last_n_tokens
45
+ self.seed = seed
46
+ self.n_threads = n_threads
47
+ self.n_batch = n_batch
48
+ self.n_ctx = n_ctx
49
+ self.n_gpu_layers = n_gpu_layers
50
+ # self.stop: list[str] = field(default_factory=lambda: [stop_string])
51
+
52
+ def update_gpu(self, new_value):
53
+ self.n_gpu_layers = new_value
54
+
55
+ class CtransInitConfig_cpu(CtransInitConfig_gpu):
56
+ def __init__(self):
57
+ super().__init__()
58
+ self.n_gpu_layers = 0
59
+
60
+ gpu_config = CtransInitConfig_gpu()
61
+ cpu_config = CtransInitConfig_cpu()
62
+
63
+
64
+ class CtransGenGenerationConfig:
65
+ def __init__(self, temperature=temperature,
66
+ top_k=top_k,
67
+ top_p=top_p,
68
+ repeat_penalty=repetition_penalty,
69
+ seed=seed,
70
+ stream=stream,
71
+ max_tokens=max_new_tokens
72
+ ):
73
+ self.temperature = temperature
74
+ self.top_k = top_k
75
+ self.top_p = top_p
76
+ self.repeat_penalty = repeat_penalty
77
+ self.seed = seed
78
+ self.max_tokens=max_tokens
79
+ self.stream = stream
80
+
81
+ def update_temp(self, new_value):
82
+ self.temperature = new_value
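For context, the generation config above is unpacked with vars() when calling the llama-cpp model, as in produce_streaming_answer_chatbot; a minimal sketch follows (model_object and full_prompt are placeholders for the loaded model and assembled prompt).

# Illustrative: unpacking the generation config into a llama-cpp call
gen_config = CtransGenGenerationConfig()
gen_config.update_temp(0.3)  # optional runtime override

output = model_object(full_prompt, **vars(gen_config))
for out in output:
    print(out["choices"][0]["text"], end="")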
chatfuncs/prompts.py CHANGED
@@ -23,8 +23,7 @@ QUESTION - {question}
23
  """
24
 
25
 
26
- instruction_prompt_template_orca = """
27
- ### System:
28
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
29
  ### User:
30
  Answer the QUESTION with a short response using information from the following CONTENT.
@@ -33,8 +32,7 @@ CONTENT: {summaries}
33
 
34
  ### Response:"""
35
 
36
- instruction_prompt_template_orca_quote = """
37
- ### System:
38
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
39
  ### User:
40
  Quote text from the CONTENT to answer the QUESTION below.
 
23
  """
24
 
25
 
26
+ instruction_prompt_template_orca = """### System:
 
27
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
28
  ### User:
29
  Answer the QUESTION with a short response using information from the following CONTENT.
 
32
 
33
  ### Response:"""
34
 
35
+ instruction_prompt_template_orca_quote = """### System:
 
36
  You are an AI assistant that follows instruction extremely well. Help as much as you can.
37
  ### User:
38
  Quote text from the CONTENT to answer the QUESTION below.
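As a reminder of how these templates are consumed, the Orca-format template is wrapped in a PromptTemplate with 'question' and 'summaries' inputs, as in the prompt-template selection near the top of this diff; the snippet below is illustrative and the import path may differ with the pinned langchain version.

# Illustrative use of the Orca instruction template
from langchain.prompts import PromptTemplate  # import path may vary by langchain version

INSTRUCTION_PROMPT = PromptTemplate(
    template=instruction_prompt_template_orca,
    input_variables=["question", "summaries"],
)
print(INSTRUCTION_PROMPT.format(question="What is the main finding?", summaries="CONTENT - ..."))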
requirements.txt CHANGED
@@ -22,3 +22,4 @@ PyStemmer==2.2.0.3
22
  scipy==1.15.2
23
  numpy==1.26.4
24
  boto3==1.38.0
 
 
22
  scipy==1.15.2
23
  numpy==1.26.4
24
  boto3==1.38.0
25
+ python-dotenv==1.1.0
requirements_gpu.txt CHANGED
@@ -20,4 +20,5 @@ bm25s==0.2.12
20
  PyStemmer==2.2.0.3
21
  scipy==1.15.2
22
  numpy==1.26.4
23
- boto3==1.38.0
 
 
20
  PyStemmer==2.2.0.3
21
  scipy==1.15.2
22
  numpy==1.26.4
23
+ boto3==1.38.0
24
+ python-dotenv==1.1.0