Commit
·
ee7464e
1
Parent(s):
0c818aa
Initial compatibility tested for use with Gemini and AWS Bedrock APIs
Browse files- .dockerignore +12 -0
- .gitignore +2 -1
- app.py +114 -115
- chatfuncs/auth.py +40 -13
- chatfuncs/chatfuncs.py +319 -208
- chatfuncs/config.py +217 -0
- chatfuncs/model_load.py +82 -0
- chatfuncs/prompts.py +2 -4
- requirements.txt +1 -0
- requirements_gpu.txt +2 -1
.dockerignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
*.ipynb
|
3 |
+
*.pdf
|
4 |
+
*.spec
|
5 |
+
*.toc
|
6 |
+
*.csv
|
7 |
+
*.bin
|
8 |
+
bootstrapper.py
|
9 |
+
build/*
|
10 |
+
dist/*
|
11 |
+
test/*
|
12 |
+
config/*
|
.gitignore
CHANGED
@@ -8,4 +8,5 @@
|
|
8 |
bootstrapper.py
|
9 |
build/*
|
10 |
dist/*
|
11 |
-
test/*
|
|
|
|
8 |
bootstrapper.py
|
9 |
build/*
|
10 |
dist/*
|
11 |
+
test/*
|
12 |
+
config/*
|
app.py
CHANGED
@@ -1,44 +1,35 @@
|
|
1 |
-
# Load in packages
|
2 |
-
|
3 |
import os
|
4 |
-
import socket
|
5 |
-
|
6 |
from typing import Type
|
7 |
-
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
8 |
from langchain_community.vectorstores import FAISS
|
9 |
import gradio as gr
|
10 |
import pandas as pd
|
11 |
|
12 |
-
from transformers import AutoTokenizer
|
13 |
-
import torch
|
14 |
-
|
15 |
-
from llama_cpp import Llama
|
16 |
-
from huggingface_hub import hf_hub_download
|
17 |
from chatfuncs.ingest import embed_faiss_save_to_zip
|
18 |
-
from chatfuncs.helper_functions import get_or_create_env_var
|
19 |
|
20 |
-
from chatfuncs.helper_functions import ensure_output_folder_exists, get_connection_params, output_folder,
|
21 |
from chatfuncs.aws_functions import upload_file_to_s3
|
22 |
-
#from chatfuncs.llm_api_call import llm_query
|
23 |
from chatfuncs.auth import authenticate_user
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
PandasDataFrame = Type[pd.DataFrame]
|
26 |
|
27 |
from datetime import datetime
|
28 |
today_rev = datetime.now().strftime("%Y%m%d")
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
access_logs_data_folder = 'logs/' + today_rev + '/' + host_name + '/'
|
35 |
-
feedback_data_folder = 'feedback/' + today_rev + '/' + host_name + '/'
|
36 |
-
usage_data_folder = 'usage/' + today_rev + '/' + host_name + '/'
|
37 |
|
38 |
# Disable cuda devices if necessary
|
39 |
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
40 |
|
41 |
-
#from chatfuncs.chatfuncs import *
|
42 |
import chatfuncs.ingest as ing
|
43 |
|
44 |
###
|
@@ -73,21 +64,44 @@ def get_faiss_store(faiss_vstore_folder,embeddings):
|
|
73 |
return vectorstore
|
74 |
|
75 |
import chatfuncs.chatfuncs as chatf
|
|
|
76 |
|
77 |
chatf.embeddings = load_embeddings(embeddings_name)
|
78 |
chatf.vectorstore = get_faiss_store(faiss_vstore_folder="faiss_embedding",embeddings=globals()["embeddings"])
|
79 |
|
|
|
80 |
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
cpu_config = chatf.cpu_config
|
89 |
-
if torch_device is None:
|
90 |
-
torch_device = chatf.torch_device
|
91 |
|
92 |
if model_type == "Phi 3.5 Mini (larger, slow)":
|
93 |
if torch_device == "cuda":
|
@@ -112,8 +126,7 @@ def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_d
|
|
112 |
)
|
113 |
|
114 |
except Exception as e:
|
115 |
-
print("GPU load failed")
|
116 |
-
print(e)
|
117 |
model = Llama(
|
118 |
model_path=hf_hub_download(
|
119 |
repo_id=os.environ.get("REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF"), #"QuantFactory/Phi-3-mini-128k-instruct-GGUF"), #, "microsoft/Phi-3-mini-4k-instruct-gguf"),#"QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
|
@@ -128,57 +141,21 @@ def load_model(model_type, gpu_layers, gpu_config=None, cpu_config=None, torch_d
|
|
128 |
# Huggingface chat model
|
129 |
hf_checkpoint = 'Qwen/Qwen2-0.5B-Instruct'# 'declare-lab/flan-alpaca-large'#'declare-lab/flan-alpaca-base' # # # 'Qwen/Qwen1.5-0.5B-Chat' #
|
130 |
|
131 |
-
|
132 |
-
|
133 |
-
from transformers import AutoModelForSeq2SeqLM, AutoModelForCausalLM
|
134 |
-
|
135 |
-
if torch_device == "cuda":
|
136 |
-
if "flan" in model_name:
|
137 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
|
138 |
-
else:
|
139 |
-
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
|
140 |
-
else:
|
141 |
-
if "flan" in model_name:
|
142 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
|
143 |
-
else:
|
144 |
-
model = AutoModelForCausalLM.from_pretrained(model_name)#, trust_remote_code=True)#, torch_dtype=torch.float16)
|
145 |
-
|
146 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = chatf.context_length)
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
|
152 |
-
chatf.
|
153 |
chatf.tokenizer = tokenizer
|
154 |
chatf.model_type = model_type
|
155 |
|
156 |
load_confirmation = "Finished loading model: " + model_type
|
157 |
|
158 |
print(load_confirmation)
|
159 |
-
return model_type, load_confirmation, model_type
|
160 |
|
161 |
-
|
162 |
-
#model_type = "Phi 3.5 Mini (larger, slow)"
|
163 |
-
#load_model(model_type, chatf.gpu_layers, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
|
164 |
-
|
165 |
-
model_type = "Qwen 2 0.5B (small, fast)"
|
166 |
-
load_model(model_type, 0, chatf.gpu_config, chatf.cpu_config, chatf.torch_device)
|
167 |
-
|
168 |
-
def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
169 |
-
|
170 |
-
print(f"> Total split documents: {len(docs_out)}")
|
171 |
-
|
172 |
-
print(docs_out)
|
173 |
-
|
174 |
-
vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
|
175 |
-
|
176 |
-
chatf.vectorstore = vectorstore_func
|
177 |
-
|
178 |
-
out_message = "Document processing complete"
|
179 |
-
|
180 |
-
return out_message, vectorstore_func
|
181 |
-
# Gradio chat
|
182 |
|
183 |
|
184 |
###
|
@@ -188,17 +165,29 @@ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
|
188 |
app = gr.Blocks(theme = gr.themes.Base(), fill_width=True)#css=".gradio-container {background-color: black}")
|
189 |
|
190 |
with app:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
ingest_text = gr.State()
|
192 |
ingest_metadata = gr.State()
|
193 |
ingest_docs = gr.State()
|
194 |
|
195 |
model_type_state = gr.State(model_type)
|
|
|
|
|
|
|
196 |
embeddings_state = gr.State(chatf.embeddings)#globals()["embeddings"])
|
197 |
vectorstore_state = gr.State(chatf.vectorstore)#globals()["vectorstore"])
|
198 |
|
199 |
relevant_query_state = gr.Checkbox(value=True, visible=False)
|
200 |
|
201 |
-
model_state = gr.State() # chatf.
|
202 |
tokenizer_state = gr.State() # chatf.tokenizer (gives error)
|
203 |
|
204 |
chat_history_state = gr.State()
|
@@ -222,7 +211,7 @@ with app:
|
|
222 |
gr.Markdown("Chat with PDF, web page or (new) csv/Excel documents. The default is a small model (Qwen 2 0.5B), that can only answer specific questions that are answered in the text. It cannot give overall impressions of, or summarise the document. The alternative (Phi 3.5 Mini (larger, slow)), can reason a little better, but is much slower (See Advanced tab).\n\nBy default the Lambeth Borough Plan '[Lambeth 2030 : Our Future, Our Lambeth](https://www.lambeth.gov.uk/better-fairer-lambeth/projects/lambeth-2030-our-future-our-lambeth)' is loaded. If you want to talk about another document or web page, please select from the second tab. If switching topic, please click the 'Clear chat' button.\n\nCaution: This is a public app. Please ensure that the document you upload is not sensitive is any way as other users may see it! Also, please note that LLM chatbots may give incomplete or incorrect information, so please use with care.")
|
223 |
|
224 |
with gr.Accordion(label="Use Gemini or AWS Claude model", open=False, visible=False):
|
225 |
-
api_model_choice = gr.Dropdown(value = "None", choices = ["gemini-
|
226 |
in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=False)
|
227 |
|
228 |
with gr.Row():
|
@@ -233,7 +222,7 @@ with app:
|
|
233 |
|
234 |
with gr.Row():
|
235 |
#chat_height = 500
|
236 |
-
chatbot = gr.Chatbot(avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='
|
237 |
with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = True):
|
238 |
sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here") # , height=chat_height
|
239 |
|
@@ -281,13 +270,12 @@ with app:
|
|
281 |
out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
|
282 |
temp_slide = gr.Slider(minimum=0.1, value = 0.5, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
|
283 |
with gr.Row():
|
284 |
-
model_choice = gr.Radio(label="Choose a chat model", value="Qwen 2 0.5B (small, fast)", choices = ["Qwen 2 0.5B (small, fast)", "Phi 3.5 Mini (larger, slow)"])
|
285 |
change_model_button = gr.Button(value="Load model", scale=0)
|
286 |
with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False):
|
287 |
gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step = 1, visible=True)
|
288 |
|
289 |
-
load_text = gr.Text(label="Load status")
|
290 |
-
|
291 |
|
292 |
gr.HTML(
|
293 |
"<center>This app is based on the models Qwen 2 0.5B and Phi 3.5 Mini. It powered by Gradio, Transformers, and Llama.cpp.</a></center>"
|
@@ -295,11 +283,41 @@ with app:
|
|
295 |
|
296 |
examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
|
304 |
# Load in a pdf
|
305 |
load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
|
@@ -318,38 +336,23 @@ with app:
|
|
318 |
success(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
|
319 |
success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
|
320 |
success(chatf.hide_block, outputs = [examples_set])
|
|
|
321 |
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
|
326 |
-
success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
|
327 |
-
success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], outputs=chatbot)
|
328 |
-
response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
329 |
-
success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
330 |
-
success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
331 |
-
|
332 |
-
response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
|
333 |
-
success(chatf.turn_off_interactivity, inputs=[message, chatbot], outputs=[message, chatbot], queue=False).\
|
334 |
-
success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state], chatbot)
|
335 |
-
response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
336 |
-
success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
337 |
-
success(lambda: chatf.restore_interactivity(), None, [message], queue=False)
|
338 |
-
|
339 |
-
# Stop box
|
340 |
-
stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])
|
341 |
-
|
342 |
-
# Clear box
|
343 |
-
clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
|
344 |
-
clear.click(lambda: None, None, chatbot, queue=False)
|
345 |
|
346 |
-
|
347 |
-
|
|
|
|
|
|
|
348 |
|
349 |
###
|
350 |
# LOGGING AND ON APP LOAD FUNCTIONS
|
351 |
###
|
352 |
-
app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state, session_hash_textbox])
|
|
|
353 |
|
354 |
# Log usernames and times of access to file (to know who is using the app when running on AWS)
|
355 |
access_callback = gr.CSVLogger()
|
@@ -358,10 +361,6 @@ with app:
|
|
358 |
session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
|
359 |
success(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
|
360 |
|
361 |
-
# Launch the Gradio app
|
362 |
-
COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
|
363 |
-
print(f'The value of COGNITO_AUTH is {COGNITO_AUTH}')
|
364 |
-
|
365 |
if __name__ == "__main__":
|
366 |
if os.environ['COGNITO_AUTH'] == "1":
|
367 |
app.queue().launch(show_error=True, auth=authenticate_user, max_file_size='50mb')
|
|
|
|
|
|
|
1 |
import os
|
|
|
|
|
2 |
from typing import Type
|
3 |
+
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
4 |
from langchain_community.vectorstores import FAISS
|
5 |
import gradio as gr
|
6 |
import pandas as pd
|
7 |
|
|
|
|
|
|
|
|
|
|
|
8 |
from chatfuncs.ingest import embed_faiss_save_to_zip
|
|
|
9 |
|
10 |
+
from chatfuncs.helper_functions import ensure_output_folder_exists, get_connection_params, output_folder, reveal_feedback_buttons, wipe_logs
|
11 |
from chatfuncs.aws_functions import upload_file_to_s3
|
|
|
12 |
from chatfuncs.auth import authenticate_user
|
13 |
+
from chatfuncs.config import FEEDBACK_LOGS_FOLDER, ACCESS_LOGS_FOLDER, USAGE_LOGS_FOLDER, HOST_NAME, COGNITO_AUTH
|
14 |
+
|
15 |
+
from llama_cpp import Llama
|
16 |
+
from huggingface_hub import hf_hub_download
|
17 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
|
18 |
+
import os
|
19 |
|
20 |
PandasDataFrame = Type[pd.DataFrame]
|
21 |
|
22 |
from datetime import datetime
|
23 |
today_rev = datetime.now().strftime("%Y%m%d")
|
24 |
|
25 |
+
host_name = HOST_NAME
|
26 |
+
access_logs_data_folder = ACCESS_LOGS_FOLDER
|
27 |
+
feedback_data_folder = FEEDBACK_LOGS_FOLDER
|
28 |
+
usage_data_folder = USAGE_LOGS_FOLDER
|
|
|
|
|
|
|
29 |
|
30 |
# Disable cuda devices if necessary
|
31 |
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
32 |
|
|
|
33 |
import chatfuncs.ingest as ing
|
34 |
|
35 |
###
|
|
|
64 |
return vectorstore
|
65 |
|
66 |
import chatfuncs.chatfuncs as chatf
|
67 |
+
from chatfuncs.model_load import torch_device, gpu_config, cpu_config, context_length
|
68 |
|
69 |
chatf.embeddings = load_embeddings(embeddings_name)
|
70 |
chatf.vectorstore = get_faiss_store(faiss_vstore_folder="faiss_embedding",embeddings=globals()["embeddings"])
|
71 |
|
72 |
+
def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
73 |
|
74 |
+
print(f"> Total split documents: {len(docs_out)}")
|
75 |
+
|
76 |
+
print(docs_out)
|
77 |
+
|
78 |
+
vectorstore_func = FAISS.from_documents(documents=docs_out, embedding=embeddings)
|
79 |
+
|
80 |
+
chatf.vectorstore = vectorstore_func
|
81 |
+
|
82 |
+
out_message = "Document processing complete"
|
83 |
+
|
84 |
+
return out_message, vectorstore_func
|
85 |
+
# Gradio chat
|
86 |
+
|
87 |
+
def create_hf_model(model_name:str):
|
88 |
+
if torch_device == "cuda":
|
89 |
+
if "flan" in model_name:
|
90 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
|
91 |
+
else:
|
92 |
+
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")#, torch_dtype=torch.float16)
|
93 |
+
else:
|
94 |
+
if "flan" in model_name:
|
95 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)#, torch_dtype=torch.float16)
|
96 |
+
else:
|
97 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)#, trust_remote_code=True)#, torch_dtype=torch.float16)
|
98 |
+
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length = context_length)
|
100 |
|
101 |
+
return model, tokenizer
|
102 |
+
|
103 |
+
def load_model(model_type:str, gpu_layers:int, gpu_config:dict=gpu_config, cpu_config:dict=cpu_config, torch_device:str=torch_device):
|
104 |
+
print("Loading model")
|
|
|
|
|
|
|
105 |
|
106 |
if model_type == "Phi 3.5 Mini (larger, slow)":
|
107 |
if torch_device == "cuda":
|
|
|
126 |
)
|
127 |
|
128 |
except Exception as e:
|
129 |
+
print("GPU load failed", e)
|
|
|
130 |
model = Llama(
|
131 |
model_path=hf_hub_download(
|
132 |
repo_id=os.environ.get("REPO_ID", "QuantFactory/Phi-3.5-mini-instruct-GGUF"), #"QuantFactory/Phi-3-mini-128k-instruct-GGUF"), #, "microsoft/Phi-3-mini-4k-instruct-gguf"),#"QuantFactory/Meta-Llama-3-8B-Instruct-GGUF-v2"), #"microsoft/Phi-3-mini-4k-instruct-gguf"),#"TheBloke/Mistral-7B-OpenOrca-GGUF"),
|
|
|
141 |
# Huggingface chat model
|
142 |
hf_checkpoint = 'Qwen/Qwen2-0.5B-Instruct'# 'declare-lab/flan-alpaca-large'#'declare-lab/flan-alpaca-base' # # # 'Qwen/Qwen1.5-0.5B-Chat' #
|
143 |
|
144 |
+
model, tokenizer = create_hf_model(model_name = hf_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
+
else:
|
147 |
+
model = model_type
|
148 |
+
tokenizer = ""
|
149 |
|
150 |
+
chatf.model_object = model
|
151 |
chatf.tokenizer = tokenizer
|
152 |
chatf.model_type = model_type
|
153 |
|
154 |
load_confirmation = "Finished loading model: " + model_type
|
155 |
|
156 |
print(load_confirmation)
|
|
|
157 |
|
158 |
+
return model_type, load_confirmation, model_type#model, tokenizer, model_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
|
161 |
###
|
|
|
165 |
app = gr.Blocks(theme = gr.themes.Base(), fill_width=True)#css=".gradio-container {background-color: black}")
|
166 |
|
167 |
with app:
|
168 |
+
model_type = "Qwen 2 0.5B (small, fast)"
|
169 |
+
load_model(model_type, 0, gpu_config, cpu_config, torch_device) # chatf.model_object, chatf.tokenizer, chatf.model_type =
|
170 |
+
|
171 |
+
print("chatf.model_object:", chatf.model_object)
|
172 |
+
|
173 |
+
# Both models are loaded on app initialisation so that users don't have to wait for the models to be downloaded
|
174 |
+
#model_type = "Phi 3.5 Mini (larger, slow)"
|
175 |
+
#load_model(model_type, gpu_layers, gpu_config, cpu_config, torch_device)
|
176 |
+
|
177 |
ingest_text = gr.State()
|
178 |
ingest_metadata = gr.State()
|
179 |
ingest_docs = gr.State()
|
180 |
|
181 |
model_type_state = gr.State(model_type)
|
182 |
+
gpu_config_state = gr.State(gpu_config)
|
183 |
+
cpu_config_state = gr.State(cpu_config)
|
184 |
+
torch_device_state = gr.State(torch_device)
|
185 |
embeddings_state = gr.State(chatf.embeddings)#globals()["embeddings"])
|
186 |
vectorstore_state = gr.State(chatf.vectorstore)#globals()["vectorstore"])
|
187 |
|
188 |
relevant_query_state = gr.Checkbox(value=True, visible=False)
|
189 |
|
190 |
+
model_state = gr.State() # chatf.model_object (gives error)
|
191 |
tokenizer_state = gr.State() # chatf.tokenizer (gives error)
|
192 |
|
193 |
chat_history_state = gr.State()
|
|
|
211 |
gr.Markdown("Chat with PDF, web page or (new) csv/Excel documents. The default is a small model (Qwen 2 0.5B), that can only answer specific questions that are answered in the text. It cannot give overall impressions of, or summarise the document. The alternative (Phi 3.5 Mini (larger, slow)), can reason a little better, but is much slower (See Advanced tab).\n\nBy default the Lambeth Borough Plan '[Lambeth 2030 : Our Future, Our Lambeth](https://www.lambeth.gov.uk/better-fairer-lambeth/projects/lambeth-2030-our-future-our-lambeth)' is loaded. If you want to talk about another document or web page, please select from the second tab. If switching topic, please click the 'Clear chat' button.\n\nCaution: This is a public app. Please ensure that the document you upload is not sensitive is any way as other users may see it! Also, please note that LLM chatbots may give incomplete or incorrect information, so please use with care.")
|
212 |
|
213 |
with gr.Accordion(label="Use Gemini or AWS Claude model", open=False, visible=False):
|
214 |
+
api_model_choice = gr.Dropdown(value = "None", choices = ["gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "None"], label="LLM model to use", multiselect=False, interactive=True, visible=False)
|
215 |
in_api_key = gr.Textbox(value = "", label="Enter Gemini API key (only if using Google API models)", lines=1, type="password",interactive=True, visible=False)
|
216 |
|
217 |
with gr.Row():
|
|
|
222 |
|
223 |
with gr.Row():
|
224 |
#chat_height = 500
|
225 |
+
chatbot = gr.Chatbot(value=None, avatar_images=('user.jfif', 'bot.jpg'), scale = 1, resizable=True, show_copy_all_button=True, show_copy_button=True, show_share_button=True, type='messages') # , height=chat_height
|
226 |
with gr.Accordion("Open this tab to see the source paragraphs used to generate the answer", open = True):
|
227 |
sources = gr.HTML(value = "Source paragraphs with the most relevant text will appear here") # , height=chat_height
|
228 |
|
|
|
270 |
out_passages = gr.Slider(minimum=1, value = 2, maximum=10, step=1, label="Choose number of passages to retrieve from the document. Numbers greater than 2 may lead to increased hallucinations or input text being truncated.")
|
271 |
temp_slide = gr.Slider(minimum=0.1, value = 0.5, maximum=1, step=0.1, label="Choose temperature setting for response generation.")
|
272 |
with gr.Row():
|
273 |
+
model_choice = gr.Radio(label="Choose a chat model", value="Qwen 2 0.5B (small, fast)", choices = ["Qwen 2 0.5B (small, fast)", "Phi 3.5 Mini (larger, slow)", "gemini-2.0-flash-001", "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro-preview-03-25", "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0"])
|
274 |
change_model_button = gr.Button(value="Load model", scale=0)
|
275 |
with gr.Accordion("Choose number of model layers to send to GPU (WARNING: please don't modify unless you are sure you have a GPU).", open = False):
|
276 |
gpu_layer_choice = gr.Slider(label="Choose number of model layers to send to GPU.", value=0, minimum=0, maximum=100, step = 1, visible=True)
|
277 |
|
278 |
+
load_text = gr.Text(label="Load status")
|
|
|
279 |
|
280 |
gr.HTML(
|
281 |
"<center>This app is based on the models Qwen 2 0.5B and Phi 3.5 Mini. It powered by Gradio, Transformers, and Llama.cpp.</a></center>"
|
|
|
283 |
|
284 |
examples_set.change(fn=chatf.update_message, inputs=[examples_set], outputs=[message])
|
285 |
|
286 |
+
|
287 |
+
###
|
288 |
+
# CHAT PAGE
|
289 |
+
###
|
290 |
+
|
291 |
+
# Click to send message
|
292 |
+
response_click = submit.click(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False, api_name="retrieval").\
|
293 |
+
success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
|
294 |
+
success(chatf.produce_streaming_answer_chatbot, inputs=[chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state], outputs=chatbot)
|
295 |
+
response_click.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
296 |
+
success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
297 |
+
success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)
|
298 |
+
|
299 |
+
# Press enter to send message
|
300 |
+
response_enter = message.submit(chatf.create_full_prompt, inputs=[message, chat_history_state, current_topic, vectorstore_state, embeddings_state, model_type_state, out_passages, api_model_choice, in_api_key], outputs=[chat_history_state, sources, instruction_prompt_out, relevant_query_state], queue=False).\
|
301 |
+
success(chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
|
302 |
+
success(chatf.produce_streaming_answer_chatbot, [chatbot, instruction_prompt_out, model_type_state, temp_slide, relevant_query_state, chat_history_state], chatbot)
|
303 |
+
response_enter.success(chatf.highlight_found_text, [chatbot, sources], [sources]).\
|
304 |
+
success(chatf.add_inputs_answer_to_history,[message, chatbot, current_topic], [chat_history_state, current_topic]).\
|
305 |
+
success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False)
|
306 |
+
|
307 |
+
# Stop box
|
308 |
+
stop.click(fn=None, inputs=None, outputs=None, cancels=[response_click, response_enter])
|
309 |
+
|
310 |
+
# Clear box
|
311 |
+
clear.click(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic])
|
312 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
313 |
+
|
314 |
+
# Thumbs up or thumbs down voting function
|
315 |
+
chatbot.like(chatf.vote, [chat_history_state, instruction_prompt_out, model_type_state], None)
|
316 |
+
|
317 |
+
|
318 |
+
###
|
319 |
+
# LOAD NEW DATA PAGE
|
320 |
+
###
|
321 |
|
322 |
# Load in a pdf
|
323 |
load_pdf_click = load_pdf.click(ing.parse_file, inputs=[in_pdf], outputs=[ingest_text, current_source]).\
|
|
|
336 |
success(ing.csv_excel_text_to_docs, inputs=[ingest_text, in_text_column], outputs=[ingest_docs]).\
|
337 |
success(embed_faiss_save_to_zip, inputs=[ingest_docs], outputs=[ingest_embed_out, vectorstore_state, file_out_box]).\
|
338 |
success(chatf.hide_block, outputs = [examples_set])
|
339 |
+
|
340 |
|
341 |
+
###
|
342 |
+
# LOAD MODEL PAGE
|
343 |
+
###
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
+
change_model_button.click(fn=chatf.turn_off_interactivity, inputs=None, outputs=[message, submit], queue=False).\
|
346 |
+
success(fn=load_model, inputs=[model_choice, gpu_layer_choice], outputs = [model_type_state, load_text, current_model]).\
|
347 |
+
success(lambda: chatf.restore_interactivity(), None, [message, submit], queue=False).\
|
348 |
+
success(chatf.clear_chat, inputs=[chat_history_state, sources, message, current_topic], outputs=[chat_history_state, sources, message, current_topic]).\
|
349 |
+
success(lambda: None, None, chatbot, queue=False)
|
350 |
|
351 |
###
|
352 |
# LOGGING AND ON APP LOAD FUNCTIONS
|
353 |
###
|
354 |
+
app.load(get_connection_params, inputs=None, outputs=[session_hash_state, s3_output_folder_state, session_hash_textbox]).\
|
355 |
+
success(load_model, inputs=[model_type_state, gpu_layer_choice, gpu_config_state, cpu_config_state, torch_device_state], outputs=[model_type_state, load_text, current_model])
|
356 |
|
357 |
# Log usernames and times of access to file (to know who is using the app when running on AWS)
|
358 |
access_callback = gr.CSVLogger()
|
|
|
361 |
session_hash_textbox.change(lambda *args: access_callback.flag(list(args)), [session_hash_textbox], None, preprocess=False).\
|
362 |
success(fn = upload_file_to_s3, inputs=[access_logs_state, access_s3_logs_loc_state], outputs=[s3_logs_output_textbox])
|
363 |
|
|
|
|
|
|
|
|
|
364 |
if __name__ == "__main__":
|
365 |
if os.environ['COGNITO_AUTH'] == "1":
|
366 |
app.queue().launch(show_error=True, auth=authenticate_user, max_file_size='50mb')
|
chatfuncs/auth.py
CHANGED
@@ -1,14 +1,22 @@
|
|
1 |
-
|
2 |
import boto3
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
|
|
7 |
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
def authenticate_user(username, password, user_pool_id=
|
12 |
"""Authenticates a user against an AWS Cognito user pool.
|
13 |
|
14 |
Args:
|
@@ -16,22 +24,39 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
|
|
16 |
client_id (str): The ID of the Cognito user pool client.
|
17 |
username (str): The username of the user.
|
18 |
password (str): The password of the user.
|
|
|
19 |
|
20 |
Returns:
|
21 |
bool: True if the user is authenticated, False otherwise.
|
22 |
"""
|
23 |
|
24 |
-
client = boto3.client('cognito-idp') # Cognito Identity Provider client
|
|
|
|
|
|
|
25 |
|
26 |
try:
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
AuthFlow='USER_PASSWORD_AUTH',
|
29 |
AuthParameters={
|
30 |
'USERNAME': username,
|
31 |
'PASSWORD': password,
|
|
|
32 |
},
|
33 |
ClientId=client_id
|
34 |
-
|
35 |
|
36 |
# If successful, you'll receive an AuthenticationResult in the response
|
37 |
if response.get('AuthenticationResult'):
|
@@ -44,5 +69,7 @@ def authenticate_user(username, password, user_pool_id=user_pool_id, client_id=c
|
|
44 |
except client.exceptions.UserNotFoundException:
|
45 |
return False
|
46 |
except Exception as e:
|
47 |
-
|
48 |
-
|
|
|
|
|
|
1 |
+
#import os
|
2 |
import boto3
|
3 |
+
#import gradio as gr
|
4 |
+
import hmac
|
5 |
+
import hashlib
|
6 |
+
import base64
|
7 |
+
from chatfuncs.config import AWS_CLIENT_ID, AWS_CLIENT_SECRET, AWS_USER_POOL_ID, AWS_REGION
|
8 |
|
9 |
+
def calculate_secret_hash(client_id:str, client_secret:str, username:str):
|
10 |
+
message = username + client_id
|
11 |
+
dig = hmac.new(
|
12 |
+
str(client_secret).encode('utf-8'),
|
13 |
+
msg=str(message).encode('utf-8'),
|
14 |
+
digestmod=hashlib.sha256
|
15 |
+
).digest()
|
16 |
+
secret_hash = base64.b64encode(dig).decode()
|
17 |
+
return secret_hash
|
18 |
|
19 |
+
def authenticate_user(username:str, password:str, user_pool_id:str=AWS_USER_POOL_ID, client_id:str=AWS_CLIENT_ID, client_secret:str=AWS_CLIENT_SECRET):
|
20 |
"""Authenticates a user against an AWS Cognito user pool.
|
21 |
|
22 |
Args:
|
|
|
24 |
client_id (str): The ID of the Cognito user pool client.
|
25 |
username (str): The username of the user.
|
26 |
password (str): The password of the user.
|
27 |
+
client_secret (str): The client secret of the app client
|
28 |
|
29 |
Returns:
|
30 |
bool: True if the user is authenticated, False otherwise.
|
31 |
"""
|
32 |
|
33 |
+
client = boto3.client('cognito-idp', region_name=AWS_REGION) # Cognito Identity Provider client
|
34 |
+
|
35 |
+
# Compute the secret hash
|
36 |
+
secret_hash = calculate_secret_hash(client_id, client_secret, username)
|
37 |
|
38 |
try:
|
39 |
+
|
40 |
+
if client_secret == '':
|
41 |
+
response = client.initiate_auth(
|
42 |
+
AuthFlow='USER_PASSWORD_AUTH',
|
43 |
+
AuthParameters={
|
44 |
+
'USERNAME': username,
|
45 |
+
'PASSWORD': password,
|
46 |
+
},
|
47 |
+
ClientId=client_id
|
48 |
+
)
|
49 |
+
|
50 |
+
else:
|
51 |
+
response = client.initiate_auth(
|
52 |
AuthFlow='USER_PASSWORD_AUTH',
|
53 |
AuthParameters={
|
54 |
'USERNAME': username,
|
55 |
'PASSWORD': password,
|
56 |
+
'SECRET_HASH': secret_hash
|
57 |
},
|
58 |
ClientId=client_id
|
59 |
+
)
|
60 |
|
61 |
# If successful, you'll receive an AuthenticationResult in the response
|
62 |
if response.get('AuthenticationResult'):
|
|
|
69 |
except client.exceptions.UserNotFoundException:
|
70 |
return False
|
71 |
except Exception as e:
|
72 |
+
out_message = f"An error occurred: {e}"
|
73 |
+
print(out_message)
|
74 |
+
raise Exception(out_message)
|
75 |
+
return False
|
chatfuncs/chatfuncs.py
CHANGED
@@ -6,11 +6,16 @@ import time
|
|
6 |
from itertools import compress
|
7 |
import pandas as pd
|
8 |
import numpy as np
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Model packages
|
11 |
import torch.cuda
|
12 |
from threading import Thread
|
13 |
from transformers import pipeline, TextIteratorStreamer
|
|
|
14 |
|
15 |
# Alternative model sources
|
16 |
#from dataclasses import asdict, dataclass
|
@@ -22,31 +27,37 @@ from langchain_community.retrievers import SVMRetriever
|
|
22 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
23 |
from langchain.docstore.document import Document
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
# For keyword extraction (not currently used)
|
26 |
#import nltk
|
27 |
#nltk.download('wordnet')
|
28 |
from nltk.corpus import stopwords
|
29 |
from nltk.tokenize import RegexpTokenizer
|
30 |
from nltk.stem import WordNetLemmatizer
|
31 |
-
#from nltk.stem.snowball import SnowballStemmer
|
32 |
from keybert import KeyBERT
|
33 |
|
34 |
# For Name Entity Recognition model
|
35 |
#from span_marker import SpanMarkerModel # Not currently used
|
36 |
|
37 |
-
|
38 |
# For BM25 retrieval
|
39 |
import bm25s
|
40 |
import Stemmer
|
41 |
|
42 |
-
|
43 |
-
#from gensim.models import TfidfModel, OkapiBM25Model
|
44 |
-
#from gensim.similarities import SparseMatrixSimilarity
|
45 |
-
|
46 |
-
from llama_cpp import Llama
|
47 |
-
from huggingface_hub import hf_hub_download
|
48 |
-
|
49 |
-
from chatfuncs.prompts import instruction_prompt_template_alpaca, instruction_prompt_mistral_orca, instruction_prompt_phi3, instruction_prompt_llama3, instruction_prompt_qwen
|
50 |
|
51 |
import gradio as gr
|
52 |
|
@@ -60,10 +71,8 @@ model_type = None # global variable setup
|
|
60 |
|
61 |
max_memory_length = 0 # How long should the memory of the conversation last?
|
62 |
|
63 |
-
|
64 |
|
65 |
-
model = [] # Define empty list for model functions to run
|
66 |
-
tokenizer = [] # Define empty list for model functions to run
|
67 |
|
68 |
## Highlight text constants
|
69 |
hlt_chunk_size = 12
|
@@ -77,84 +86,6 @@ ner_model = []#SpanMarkerModel.from_pretrained("tomaarsen/span-marker-mbert-base
|
|
77 |
# Used to pull out keywords from chat history to add to user queries behind the scenes
|
78 |
kw_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
|
79 |
|
80 |
-
# Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
|
81 |
-
if torch.cuda.is_available():
|
82 |
-
torch_device = "cuda"
|
83 |
-
gpu_layers = 100
|
84 |
-
else:
|
85 |
-
torch_device = "cpu"
|
86 |
-
gpu_layers = 0
|
87 |
-
|
88 |
-
print("Running on device:", torch_device)
|
89 |
-
threads = 8 #torch.get_num_threads()
|
90 |
-
print("CPU threads:", threads)
|
91 |
-
|
92 |
-
# Qwen 2 0.5B (small, fast) Model parameters
|
93 |
-
temperature: float = 0.1
|
94 |
-
top_k: int = 3
|
95 |
-
top_p: float = 1
|
96 |
-
repetition_penalty: float = 1.15
|
97 |
-
flan_alpaca_repetition_penalty: float = 1.3
|
98 |
-
last_n_tokens: int = 64
|
99 |
-
max_new_tokens: int = 1024
|
100 |
-
seed: int = 42
|
101 |
-
reset: bool = False
|
102 |
-
stream: bool = True
|
103 |
-
threads: int = threads
|
104 |
-
batch_size:int = 256
|
105 |
-
context_length:int = 2048
|
106 |
-
sample = True
|
107 |
-
|
108 |
-
|
109 |
-
class CtransInitConfig_gpu:
|
110 |
-
def __init__(self,
|
111 |
-
last_n_tokens=last_n_tokens,
|
112 |
-
seed=seed,
|
113 |
-
n_threads=threads,
|
114 |
-
n_batch=batch_size,
|
115 |
-
n_ctx=4096,
|
116 |
-
n_gpu_layers=gpu_layers):
|
117 |
-
|
118 |
-
self.last_n_tokens = last_n_tokens
|
119 |
-
self.seed = seed
|
120 |
-
self.n_threads = n_threads
|
121 |
-
self.n_batch = n_batch
|
122 |
-
self.n_ctx = n_ctx
|
123 |
-
self.n_gpu_layers = n_gpu_layers
|
124 |
-
# self.stop: list[str] = field(default_factory=lambda: [stop_string])
|
125 |
-
|
126 |
-
def update_gpu(self, new_value):
|
127 |
-
self.n_gpu_layers = new_value
|
128 |
-
|
129 |
-
class CtransInitConfig_cpu(CtransInitConfig_gpu):
|
130 |
-
def __init__(self):
|
131 |
-
super().__init__()
|
132 |
-
self.n_gpu_layers = 0
|
133 |
-
|
134 |
-
gpu_config = CtransInitConfig_gpu()
|
135 |
-
cpu_config = CtransInitConfig_cpu()
|
136 |
-
|
137 |
-
|
138 |
-
class CtransGenGenerationConfig:
|
139 |
-
def __init__(self, temperature=temperature,
|
140 |
-
top_k=top_k,
|
141 |
-
top_p=top_p,
|
142 |
-
repeat_penalty=repetition_penalty,
|
143 |
-
seed=seed,
|
144 |
-
stream=stream,
|
145 |
-
max_tokens=max_new_tokens
|
146 |
-
):
|
147 |
-
self.temperature = temperature
|
148 |
-
self.top_k = top_k
|
149 |
-
self.top_p = top_p
|
150 |
-
self.repeat_penalty = repeat_penalty
|
151 |
-
self.seed = seed
|
152 |
-
self.max_tokens=max_tokens
|
153 |
-
self.stream = stream
|
154 |
-
|
155 |
-
def update_temp(self, new_value):
|
156 |
-
self.temperature = new_value
|
157 |
-
|
158 |
# Vectorstore funcs
|
159 |
|
160 |
def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
@@ -187,7 +118,7 @@ def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
|
187 |
|
188 |
# Prompt functions
|
189 |
|
190 |
-
def base_prompt_templates(model_type = "Qwen 2 0.5B (small, fast)"):
|
191 |
|
192 |
#EXAMPLE_PROMPT = PromptTemplate(
|
193 |
# template="\nCONTENT:\n\n{page_content}\n\nSOURCE: {source}\n\n",
|
@@ -205,20 +136,20 @@ def base_prompt_templates(model_type = "Qwen 2 0.5B (small, fast)"):
|
|
205 |
INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_qwen, input_variables=['question', 'summaries'])
|
206 |
elif model_type == "Phi 3.5 Mini (larger, slow)":
|
207 |
INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_phi3, input_variables=['question', 'summaries'])
|
|
|
|
|
|
|
208 |
|
209 |
return INSTRUCTION_PROMPT, CONTENT_PROMPT
|
210 |
|
211 |
-
def write_out_metadata_as_string(metadata_in):
|
212 |
metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
|
213 |
return metadata_string
|
214 |
|
215 |
-
def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag = True, out_passages = 2): # ,
|
216 |
|
217 |
question = inputs["question"]
|
218 |
chat_history = inputs["chat_history"]
|
219 |
-
|
220 |
-
print("relevant_flag in generate_expanded_prompt:", relevant_flag)
|
221 |
-
|
222 |
|
223 |
if relevant_flag == True:
|
224 |
new_question_kworded = adapt_q_from_chat_history(question, chat_history, extracted_memory) # new_question_keywords,
|
@@ -234,8 +165,6 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
|
|
234 |
return sorry_prompt, "No relevant sources found.", new_question_kworded
|
235 |
|
236 |
# Expand the found passages to the neighbouring context
|
237 |
-
print("Doc_df columns:", doc_df.columns)
|
238 |
-
|
239 |
if 'meta_url' in doc_df.columns:
|
240 |
file_type = determine_file_type(doc_df['meta_url'][0])
|
241 |
else:
|
@@ -265,14 +194,22 @@ def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt, content
|
|
265 |
|
266 |
return instruction_prompt_out, sources_docs_content_string, new_question_kworded
|
267 |
|
268 |
-
def create_full_prompt(user_input,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
#if chain_agent is None:
|
271 |
# history.append((user_input, "Please click the button to submit the Huggingface API key before using the chatbot (top right)"))
|
272 |
# return history, history, "", ""
|
273 |
print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
|
274 |
-
|
275 |
-
|
276 |
history = history or []
|
277 |
|
278 |
if api_model_choice and api_model_choice != "None":
|
@@ -298,31 +235,13 @@ def create_full_prompt(user_input, history, extracted_memory, vectorstore, embed
|
|
298 |
generate_expanded_prompt({"question": user_input, "chat_history": history}, #vectorstore,
|
299 |
instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag, out_passages)
|
300 |
|
301 |
-
history.append(user_input)
|
302 |
|
303 |
print("Output history is:", history)
|
304 |
print("Final prompt to model is:",instruction_prompt_out)
|
305 |
|
306 |
return history, docs_content_string, instruction_prompt_out, relevant_flag
|
307 |
|
308 |
-
# Chat functions
|
309 |
-
import boto3
|
310 |
-
import json
|
311 |
-
from chatfuncs.helper_functions import get_or_create_env_var
|
312 |
-
|
313 |
-
# ResponseObject class for AWS Bedrock calls
|
314 |
-
class ResponseObject:
|
315 |
-
def __init__(self, text, usage_metadata):
|
316 |
-
self.text = text
|
317 |
-
self.usage_metadata = usage_metadata
|
318 |
-
|
319 |
-
max_tokens = 4096
|
320 |
-
|
321 |
-
AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', 'eu-west-2')
|
322 |
-
print(f'The value of AWS_DEFAULT_REGION is {AWS_DEFAULT_REGION}')
|
323 |
-
|
324 |
-
bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
|
325 |
-
|
326 |
def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
|
327 |
"""
|
328 |
This function sends a request to AWS Claude with the following parameters:
|
@@ -351,6 +270,8 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
|
|
351 |
],
|
352 |
}
|
353 |
|
|
|
|
|
354 |
body = json.dumps(prompt_config)
|
355 |
|
356 |
modelId = model_choice
|
@@ -376,16 +297,173 @@ def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tok
|
|
376 |
|
377 |
return response
|
378 |
|
379 |
-
def
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
):
|
390 |
#print("Model type is: ", model_type)
|
391 |
|
@@ -395,16 +473,19 @@ def produce_streaming_answer_chatbot(history,
|
|
395 |
|
396 |
# return history
|
397 |
|
398 |
-
|
|
|
|
|
399 |
|
400 |
if relevant_query_bool == False:
|
401 |
-
|
402 |
-
history.append(out_message)
|
403 |
|
404 |
yield history
|
405 |
return
|
406 |
|
407 |
if model_type == "Qwen 2 0.5B (small, fast)":
|
|
|
|
|
408 |
# Get the model and tokenizer, and tokenize the user text.
|
409 |
model_inputs = tokenizer(text=full_prompt, return_tensors="pt", return_attention_mask=False).to(torch_device)
|
410 |
|
@@ -422,9 +503,9 @@ def produce_streaming_answer_chatbot(history,
|
|
422 |
top_k=top_k
|
423 |
)
|
424 |
|
425 |
-
|
426 |
|
427 |
-
t = Thread(target=
|
428 |
t.start()
|
429 |
|
430 |
# Pull the generated text from the streamer, and update the model output.
|
@@ -432,12 +513,14 @@ def produce_streaming_answer_chatbot(history,
|
|
432 |
NUM_TOKENS=0
|
433 |
print('-'*4+'Start Generation'+'-'*4)
|
434 |
|
435 |
-
history
|
|
|
436 |
for new_text in streamer:
|
437 |
try:
|
438 |
-
if new_text
|
439 |
-
|
440 |
-
|
|
|
441 |
yield history
|
442 |
except Exception as e:
|
443 |
print(f"Error during text generation: {e}")
|
@@ -463,14 +546,15 @@ def produce_streaming_answer_chatbot(history,
|
|
463 |
NUM_TOKENS=0
|
464 |
print('-'*4+'Start Generation'+'-'*4)
|
465 |
|
466 |
-
output =
|
467 |
full_prompt, **vars(gen_config))
|
468 |
|
469 |
-
history
|
|
|
470 |
for out in output:
|
471 |
|
472 |
if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
|
473 |
-
history[-1][
|
474 |
NUM_TOKENS+=1
|
475 |
yield history
|
476 |
else:
|
@@ -481,36 +565,75 @@ def produce_streaming_answer_chatbot(history,
|
|
481 |
print('-'*4+'End Generation'+'-'*4)
|
482 |
print(f'Num of generated tokens: {NUM_TOKENS}')
|
483 |
print(f'Time for complete generation: {time_generate}s')
|
484 |
-
print(f'Tokens per
|
485 |
print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
|
486 |
|
487 |
-
elif
|
488 |
system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
|
489 |
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
|
500 |
-
time.sleep(30)
|
501 |
-
response = call_aws_claude(full_prompt, system_prompt, temperature, max_tokens, model_type)
|
502 |
-
|
503 |
-
except Exception as e:
|
504 |
-
print(e)
|
505 |
-
return "", history
|
506 |
# Update the conversation history with the new prompt and response
|
507 |
-
|
508 |
-
|
509 |
|
510 |
-
|
511 |
-
#print("conversation_history:", conversation_history)
|
512 |
|
513 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
|
515 |
# Chat helper functions
|
516 |
|
@@ -589,9 +712,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
|
|
589 |
|
590 |
docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
|
591 |
|
592 |
-
print("Docs from similarity search:")
|
593 |
-
print(docs)
|
594 |
-
|
595 |
# Keep only documents with a certain score
|
596 |
docs_len = [len(x[0].page_content) for x in docs]
|
597 |
docs_scores = [x[1] for x in docs]
|
@@ -688,12 +808,8 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
|
|
688 |
# 3rd level check on retrieved docs with SVM retriever
|
689 |
# Check the type of the embeddings object
|
690 |
embeddings_type = type(embeddings)
|
691 |
-
print("Type of embeddings object:", embeddings_type)
|
692 |
-
|
693 |
|
694 |
-
print("embeddings:", embeddings)
|
695 |
|
696 |
-
from langchain_huggingface import HuggingFaceEmbeddings
|
697 |
#hf_embeddings = HuggingFaceEmbeddings(**embeddings)
|
698 |
hf_embeddings = embeddings
|
699 |
|
@@ -743,10 +859,6 @@ def hybrid_retrieval(new_question_kworded, vectorstore, embeddings, k_val, out_p
|
|
743 |
# Make df of best options
|
744 |
doc_df = create_doc_df(docs_keep_out)
|
745 |
|
746 |
-
print("doc_df:",doc_df)
|
747 |
-
print("docs_keep_as_doc:",docs_keep_as_doc)
|
748 |
-
print("docs_keep_out:", docs_keep_out)
|
749 |
-
|
750 |
return docs_keep_as_doc, doc_df, docs_keep_out
|
751 |
|
752 |
def get_expanded_passages(vectorstore, docs, width):
|
@@ -836,16 +948,16 @@ def get_expanded_passages(vectorstore, docs, width):
|
|
836 |
|
837 |
return expanded_docs, doc_df
|
838 |
|
839 |
-
def highlight_found_text(
|
840 |
"""
|
841 |
-
Highlights occurrences of
|
842 |
|
843 |
Parameters:
|
844 |
-
-
|
845 |
-
-
|
846 |
|
847 |
Returns:
|
848 |
-
- str: A string with occurrences of
|
849 |
|
850 |
Example:
|
851 |
>>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
|
@@ -859,32 +971,31 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
|
|
859 |
return text[i][0].replace(" ", " ").strip()
|
860 |
else:
|
861 |
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
862 |
|
863 |
-
|
864 |
-
|
865 |
-
|
866 |
-
elif isinstance(text, list):
|
867 |
-
return text[-1][1].replace(" ", " ").strip()
|
868 |
-
else:
|
869 |
-
return ""
|
870 |
-
|
871 |
-
full_text = extract_text_from_input(full_text)
|
872 |
-
search_text = extract_search_text_from_input(search_text)
|
873 |
-
|
874 |
|
|
|
875 |
|
876 |
text_splitter = RecursiveCharacterTextSplitter(
|
877 |
chunk_size=hlt_chunk_size,
|
878 |
separators=hlt_strat,
|
879 |
chunk_overlap=hlt_overlap,
|
880 |
)
|
881 |
-
sections = text_splitter.split_text(
|
882 |
|
883 |
found_positions = {}
|
884 |
for x in sections:
|
885 |
text_start_pos = 0
|
886 |
while text_start_pos != -1:
|
887 |
-
text_start_pos =
|
888 |
if text_start_pos != -1:
|
889 |
found_positions[text_start_pos] = text_start_pos + len(x)
|
890 |
text_start_pos += 1
|
@@ -907,20 +1018,24 @@ def highlight_found_text(search_text: str, full_text: str, hlt_chunk_size:int=hl
|
|
907 |
prev_end = 0
|
908 |
for start, end in combined_positions:
|
909 |
if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
|
910 |
-
pos_tokens.append(
|
911 |
-
pos_tokens.append('<mark style="color:black;">' +
|
912 |
prev_end = end
|
913 |
-
pos_tokens.append(
|
|
|
|
|
914 |
|
915 |
-
|
|
|
|
|
916 |
|
917 |
|
918 |
# # Chat history functions
|
919 |
|
920 |
def clear_chat(chat_history_state, sources, chat_message, current_topic):
|
921 |
-
chat_history_state =
|
922 |
sources = ''
|
923 |
-
chat_message =
|
924 |
current_topic = ''
|
925 |
|
926 |
return chat_history_state, sources, chat_message, current_topic
|
@@ -1011,8 +1126,7 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
|
|
1011 |
for word in tokens_without_sw:
|
1012 |
if word not in ordered_tokens:
|
1013 |
ordered_tokens.add(word)
|
1014 |
-
result.append(word)
|
1015 |
-
|
1016 |
|
1017 |
|
1018 |
new_question_keywords = ' '.join(result)
|
@@ -1021,9 +1135,6 @@ def remove_q_stopwords(question): # Remove stopwords from question. Not used at
|
|
1021 |
def remove_q_ner_extractor(question):
|
1022 |
|
1023 |
predict_out = ner_model.predict(question)
|
1024 |
-
|
1025 |
-
|
1026 |
-
|
1027 |
predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]
|
1028 |
|
1029 |
# Remove duplicate words while preserving order
|
@@ -1075,11 +1186,11 @@ def keybert_keywords(text, n, kw_model):
|
|
1075 |
return keywords_list
|
1076 |
|
1077 |
# Gradio functions
|
1078 |
-
def turn_off_interactivity(
|
1079 |
-
return gr.
|
1080 |
|
1081 |
def restore_interactivity():
|
1082 |
-
return gr.
|
1083 |
|
1084 |
def update_message(dropdown_value):
|
1085 |
return gr.Textbox(value=dropdown_value)
|
|
|
6 |
from itertools import compress
|
7 |
import pandas as pd
|
8 |
import numpy as np
|
9 |
+
import google.generativeai as ai
|
10 |
+
from gradio import Progress
|
11 |
+
import boto3
|
12 |
+
import json
|
13 |
|
14 |
# Model packages
|
15 |
import torch.cuda
|
16 |
from threading import Thread
|
17 |
from transformers import pipeline, TextIteratorStreamer
|
18 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
19 |
|
20 |
# Alternative model sources
|
21 |
#from dataclasses import asdict, dataclass
|
|
|
27 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
28 |
from langchain.docstore.document import Document
|
29 |
|
30 |
+
from chatfuncs.config import GEMINI_API_KEY, AWS_DEFAULT_REGION
|
31 |
+
|
32 |
+
model_object = [] # Define empty list for model functions to run
|
33 |
+
tokenizer = [] # Define empty list for model functions to run
|
34 |
+
|
35 |
+
from chatfuncs.model_load import temperature, max_new_tokens, sample, repetition_penalty, top_p, top_k, torch_device, CtransGenGenerationConfig, max_tokens
|
36 |
+
|
37 |
+
# ResponseObject class for AWS Bedrock calls
|
38 |
+
class ResponseObject:
|
39 |
+
def __init__(self, text, usage_metadata):
|
40 |
+
self.text = text
|
41 |
+
self.usage_metadata = usage_metadata
|
42 |
+
|
43 |
+
bedrock_runtime = boto3.client('bedrock-runtime', region_name=AWS_DEFAULT_REGION)
|
44 |
+
|
45 |
# For keyword extraction (not currently used)
|
46 |
#import nltk
|
47 |
#nltk.download('wordnet')
|
48 |
from nltk.corpus import stopwords
|
49 |
from nltk.tokenize import RegexpTokenizer
|
50 |
from nltk.stem import WordNetLemmatizer
|
|
|
51 |
from keybert import KeyBERT
|
52 |
|
53 |
# For Name Entity Recognition model
|
54 |
#from span_marker import SpanMarkerModel # Not currently used
|
55 |
|
|
|
56 |
# For BM25 retrieval
|
57 |
import bm25s
|
58 |
import Stemmer
|
59 |
|
60 |
+
from chatfuncs.prompts import instruction_prompt_template_alpaca, instruction_prompt_mistral_orca, instruction_prompt_phi3, instruction_prompt_llama3, instruction_prompt_qwen, instruction_prompt_template_orca
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
import gradio as gr
|
63 |
|
|
|
71 |
|
72 |
max_memory_length = 0 # How long should the memory of the conversation last?
|
73 |
|
74 |
+
source_texts = "" # Define dummy source text (full text) just to enable highlight function to load
|
75 |
|
|
|
|
|
76 |
|
77 |
## Highlight text constants
|
78 |
hlt_chunk_size = 12
|
|
|
86 |
# Used to pull out keywords from chat history to add to user queries behind the scenes
|
87 |
kw_model = pipeline("feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2")
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
# Vectorstore funcs
|
90 |
|
91 |
def docs_to_faiss_save(docs_out:PandasDataFrame, embeddings=embeddings):
|
|
|
118 |
|
119 |
# Prompt functions
|
120 |
|
121 |
+
def base_prompt_templates(model_type:str = "Qwen 2 0.5B (small, fast)"):
|
122 |
|
123 |
#EXAMPLE_PROMPT = PromptTemplate(
|
124 |
# template="\nCONTENT:\n\n{page_content}\n\nSOURCE: {source}\n\n",
|
|
|
136 |
INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_qwen, input_variables=['question', 'summaries'])
|
137 |
elif model_type == "Phi 3.5 Mini (larger, slow)":
|
138 |
INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_phi3, input_variables=['question', 'summaries'])
|
139 |
+
else:
|
140 |
+
INSTRUCTION_PROMPT=PromptTemplate(template=instruction_prompt_template_orca, input_variables=['question', 'summaries'])
|
141 |
+
|
142 |
|
143 |
return INSTRUCTION_PROMPT, CONTENT_PROMPT
|
144 |
|
145 |
+
def write_out_metadata_as_string(metadata_in:str):
|
146 |
metadata_string = [f"{' '.join(f'{k}: {v}' for k, v in d.items() if k != 'page_section')}" for d in metadata_in] # ['metadata']
|
147 |
return metadata_string
|
148 |
|
149 |
+
def generate_expanded_prompt(inputs: Dict[str, str], instruction_prompt:str, content_prompt:str, extracted_memory:list, vectorstore:object, embeddings:object, relevant_flag:bool = True, out_passages:int = 2): # ,
|
150 |
|
151 |
question = inputs["question"]
|
152 |
chat_history = inputs["chat_history"]
|
|
|
|
|
|
|
153 |
|
154 |
if relevant_flag == True:
|
155 |
new_question_kworded = adapt_q_from_chat_history(question, chat_history, extracted_memory) # new_question_keywords,
|
|
|
165 |
return sorry_prompt, "No relevant sources found.", new_question_kworded
|
166 |
|
167 |
# Expand the found passages to the neighbouring context
|
|
|
|
|
168 |
if 'meta_url' in doc_df.columns:
|
169 |
file_type = determine_file_type(doc_df['meta_url'][0])
|
170 |
else:
|
|
|
194 |
|
195 |
return instruction_prompt_out, sources_docs_content_string, new_question_kworded
|
196 |
|
197 |
+
def create_full_prompt(user_input:str,
|
198 |
+
history:list[dict],
|
199 |
+
extracted_memory:str,
|
200 |
+
vectorstore:object,
|
201 |
+
embeddings:object,
|
202 |
+
model_type:str,
|
203 |
+
out_passages:list[str],
|
204 |
+
api_model_choice=None,
|
205 |
+
api_key=None,
|
206 |
+
relevant_flag = True):
|
207 |
|
208 |
#if chain_agent is None:
|
209 |
# history.append((user_input, "Please click the button to submit the Huggingface API key before using the chatbot (top right)"))
|
210 |
# return history, history, "", ""
|
211 |
print("\n==== date/time: " + str(datetime.datetime.now()) + " ====")
|
212 |
+
|
|
|
213 |
history = history or []
|
214 |
|
215 |
if api_model_choice and api_model_choice != "None":
|
|
|
235 |
generate_expanded_prompt({"question": user_input, "chat_history": history}, #vectorstore,
|
236 |
instruction_prompt, content_prompt, extracted_memory, vectorstore, embeddings, relevant_flag, out_passages)
|
237 |
|
238 |
+
history.append({"metadata":None, "options":None, "role": 'user', "content": user_input})
|
239 |
|
240 |
print("Output history is:", history)
|
241 |
print("Final prompt to model is:",instruction_prompt_out)
|
242 |
|
243 |
return history, docs_content_string, instruction_prompt_out, relevant_flag
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
def call_aws_claude(prompt: str, system_prompt: str, temperature: float, max_tokens: int, model_choice: str) -> ResponseObject:
|
246 |
"""
|
247 |
This function sends a request to AWS Claude with the following parameters:
|
|
|
270 |
],
|
271 |
}
|
272 |
|
273 |
+
print("prompt_config:", prompt_config)
|
274 |
+
|
275 |
body = json.dumps(prompt_config)
|
276 |
|
277 |
modelId = model_choice
|
|
|
297 |
|
298 |
return response
|
299 |
|
300 |
+
def construct_gemini_generative_model(in_api_key: str, temperature: float, model_choice: str, system_prompt: str, max_tokens: int) -> Tuple[object, dict]:
|
301 |
+
"""
|
302 |
+
Constructs a GenerativeModel for Gemini API calls.
|
303 |
+
|
304 |
+
Parameters:
|
305 |
+
- in_api_key (str): The API key for authentication.
|
306 |
+
- temperature (float): The temperature parameter for the model, controlling the randomness of the output.
|
307 |
+
- model_choice (str): The choice of model to use for generation.
|
308 |
+
- system_prompt (str): The system prompt to guide the generation.
|
309 |
+
- max_tokens (int): The maximum number of tokens to generate.
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
- Tuple[object, dict]: A tuple containing the constructed GenerativeModel and its configuration.
|
313 |
+
"""
|
314 |
+
# Construct a GenerativeModel
|
315 |
+
try:
|
316 |
+
if in_api_key:
|
317 |
+
#print("Getting API key from textbox")
|
318 |
+
api_key = in_api_key
|
319 |
+
ai.configure(api_key=api_key)
|
320 |
+
elif "GOOGLE_API_KEY" in os.environ:
|
321 |
+
#print("Searching for API key in environmental variables")
|
322 |
+
api_key = os.environ["GOOGLE_API_KEY"]
|
323 |
+
ai.configure(api_key=api_key)
|
324 |
+
else:
|
325 |
+
print("No API key foound")
|
326 |
+
raise gr.Error("No API key found.")
|
327 |
+
except Exception as e:
|
328 |
+
print(e)
|
329 |
+
|
330 |
+
config = ai.GenerationConfig(temperature=temperature, max_output_tokens=max_tokens)
|
331 |
+
|
332 |
+
print("model_choice:", model_choice)
|
333 |
+
|
334 |
+
#model = ai.GenerativeModel.from_cached_content(cached_content=cache, generation_config=config)
|
335 |
+
model = ai.GenerativeModel(model_name=model_choice, system_instruction=system_prompt, generation_config=config)
|
336 |
+
|
337 |
+
return model, config
|
338 |
+
|
339 |
+
# Function to send a request and update history
|
340 |
+
def send_request(prompt: str, conversation_history: List[dict], model: object, config: dict, model_choice: str, system_prompt: str, temperature: float, progress=Progress(track_tqdm=True)) -> Tuple[str, List[dict]]:
|
341 |
+
"""
|
342 |
+
This function sends a request to a language model with the given prompt, conversation history, model configuration, model choice, system prompt, and temperature.
|
343 |
+
It constructs the full prompt by appending the new user prompt to the conversation history, generates a response from the model, and updates the conversation history with the new prompt and response.
|
344 |
+
If the model choice is specific to AWS Claude, it calls the `call_aws_claude` function; otherwise, it uses the `model.generate_content` method.
|
345 |
+
The function returns the response text and the updated conversation history.
|
346 |
+
"""
|
347 |
+
# Constructing the full prompt from the conversation history
|
348 |
+
full_prompt = "Conversation history:\n"
|
349 |
+
|
350 |
+
for entry in conversation_history:
|
351 |
+
role = entry['role'].capitalize() # Assuming the history is stored with 'role' and 'content'
|
352 |
+
message = ' '.join(entry['parts']) # Combining all parts of the message
|
353 |
+
full_prompt += f"{role}: {message}\n"
|
354 |
+
|
355 |
+
# Adding the new user prompt
|
356 |
+
full_prompt += f"\nUser: {prompt}"
|
357 |
+
|
358 |
+
# Print the full prompt for debugging purposes
|
359 |
+
#print("full_prompt:", full_prompt)
|
360 |
+
|
361 |
+
# Generate the model's response
|
362 |
+
if "gemini" in model_choice:
|
363 |
+
try:
|
364 |
+
response = model.generate_content(contents=full_prompt, generation_config=config)
|
365 |
+
except Exception as e:
|
366 |
+
# If fails, try again after 10 seconds in case there is a throttle limit
|
367 |
+
print(e)
|
368 |
+
try:
|
369 |
+
print("Calling Gemini model")
|
370 |
+
out_message = "API limit hit - waiting 30 seconds to retry."
|
371 |
+
print(out_message)
|
372 |
+
progress(0.5, desc=out_message)
|
373 |
+
time.sleep(30)
|
374 |
+
response = model.generate_content(contents=full_prompt, generation_config=config)
|
375 |
+
except Exception as e:
|
376 |
+
print(e)
|
377 |
+
return "", conversation_history
|
378 |
+
elif "claude" in model_choice:
|
379 |
+
try:
|
380 |
+
print("Calling AWS Claude model")
|
381 |
+
print("prompt:", prompt)
|
382 |
+
print("system_prompt:", system_prompt)
|
383 |
+
response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
|
384 |
+
except Exception as e:
|
385 |
+
# If fails, try again after x seconds in case there is a throttle limit
|
386 |
+
print(e)
|
387 |
+
try:
|
388 |
+
out_message = "API limit hit - waiting 30 seconds to retry."
|
389 |
+
print(out_message)
|
390 |
+
progress(0.5, desc=out_message)
|
391 |
+
time.sleep(30)
|
392 |
+
response = call_aws_claude(prompt, system_prompt, temperature, max_tokens, model_choice)
|
393 |
+
|
394 |
+
except Exception as e:
|
395 |
+
print(e)
|
396 |
+
return "", conversation_history
|
397 |
+
else:
|
398 |
+
raise Exception("Model not found")
|
399 |
+
|
400 |
+
# Update the conversation history with the new prompt and response
|
401 |
+
conversation_history.append({"metadata":None, "options":None, "role": 'user', 'parts': [prompt]})
|
402 |
+
conversation_history.append({"metadata":None, "options":None, "role": "assistant", 'parts': [response.text]})
|
403 |
+
|
404 |
+
# Print the updated conversation history
|
405 |
+
#print("conversation_history:", conversation_history)
|
406 |
+
|
407 |
+
return response, conversation_history
|
408 |
+
|
409 |
+
def process_requests(prompts: List[str], system_prompt_with_table: str, conversation_history: List[dict], whole_conversation: List[str], whole_conversation_metadata: List[str], model: object, config: dict, model_choice: str, temperature: float, batch_no:int = 1, master:bool = False) -> Tuple[List[ResponseObject], List[dict], List[str], List[str]]:
|
410 |
+
"""
|
411 |
+
Processes a list of prompts by sending them to the model, appending the responses to the conversation history, and updating the whole conversation and metadata.
|
412 |
+
|
413 |
+
Args:
|
414 |
+
prompts (List[str]): A list of prompts to be processed.
|
415 |
+
system_prompt_with_table (str): The system prompt including a table.
|
416 |
+
conversation_history (List[dict]): The history of the conversation.
|
417 |
+
whole_conversation (List[str]): The complete conversation including prompts and responses.
|
418 |
+
whole_conversation_metadata (List[str]): Metadata about the whole conversation.
|
419 |
+
model (object): The model to use for processing the prompts.
|
420 |
+
config (dict): Configuration for the model.
|
421 |
+
model_choice (str): The choice of model to use.
|
422 |
+
temperature (float): The temperature parameter for the model.
|
423 |
+
batch_no (int): Batch number of the large language model request.
|
424 |
+
master (bool): Is this request for the master table.
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
Tuple[List[ResponseObject], List[dict], List[str], List[str]]: A tuple containing the list of responses, the updated conversation history, the updated whole conversation, and the updated whole conversation metadata.
|
428 |
+
"""
|
429 |
+
responses = []
|
430 |
+
#for prompt in prompts:
|
431 |
+
|
432 |
+
response, conversation_history = send_request(prompts[0], conversation_history, model=model, config=config, model_choice=model_choice, system_prompt=system_prompt_with_table, temperature=temperature)
|
433 |
+
|
434 |
+
print(response.text)
|
435 |
+
#"Okay, I'm ready. What source are we discussing, and what's your question about it? Please provide as much context as possible so I can give you the best answer."]
|
436 |
+
print(response.usage_metadata)
|
437 |
+
responses.append(response)
|
438 |
+
|
439 |
+
# Create conversation txt object
|
440 |
+
whole_conversation.append(prompts[0])
|
441 |
+
whole_conversation.append(response.text)
|
442 |
+
|
443 |
+
# Create conversation metadata
|
444 |
+
if master == False:
|
445 |
+
whole_conversation_metadata.append(f"Query batch {batch_no} prompt {len(responses)} metadata:")
|
446 |
+
else:
|
447 |
+
whole_conversation_metadata.append(f"Query summary metadata:")
|
448 |
+
|
449 |
+
whole_conversation_metadata.append(str(response.usage_metadata))
|
450 |
+
|
451 |
+
return responses, conversation_history, whole_conversation, whole_conversation_metadata
|
452 |
+
|
453 |
+
def produce_streaming_answer_chatbot(
|
454 |
+
history:list,
|
455 |
+
full_prompt:str,
|
456 |
+
model_type:str,
|
457 |
+
temperature:float=temperature,
|
458 |
+
relevant_query_bool:bool=True,
|
459 |
+
chat_history:list[dict]=[{"metadata":None, "options":None, "role": 'user', "content": ""}],
|
460 |
+
max_new_tokens:int=max_new_tokens,
|
461 |
+
sample:bool=sample,
|
462 |
+
repetition_penalty:float=repetition_penalty,
|
463 |
+
top_p:float=top_p,
|
464 |
+
top_k:float=top_k,
|
465 |
+
max_tokens:int=max_tokens,
|
466 |
+
in_api_key:str=GEMINI_API_KEY
|
467 |
):
|
468 |
#print("Model type is: ", model_type)
|
469 |
|
|
|
473 |
|
474 |
# return history
|
475 |
|
476 |
+
history = chat_history
|
477 |
+
|
478 |
+
print("history at start of streaming function:", history)
|
479 |
|
480 |
if relevant_query_bool == False:
|
481 |
+
history.append({"metadata":None, "options":None, "role": "assistant", "content": 'No relevant query found. Please retry your question'})
|
|
|
482 |
|
483 |
yield history
|
484 |
return
|
485 |
|
486 |
if model_type == "Qwen 2 0.5B (small, fast)":
|
487 |
+
|
488 |
+
print("tokenizer:", tokenizer)
|
489 |
# Get the model and tokenizer, and tokenize the user text.
|
490 |
model_inputs = tokenizer(text=full_prompt, return_tensors="pt", return_attention_mask=False).to(torch_device)
|
491 |
|
|
|
503 |
top_k=top_k
|
504 |
)
|
505 |
|
506 |
+
print("model_object:", model_object)
|
507 |
|
508 |
+
t = Thread(target=model_object.generate, kwargs=generate_kwargs)
|
509 |
t.start()
|
510 |
|
511 |
# Pull the generated text from the streamer, and update the model output.
|
|
|
513 |
NUM_TOKENS=0
|
514 |
print('-'*4+'Start Generation'+'-'*4)
|
515 |
|
516 |
+
history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
|
517 |
+
|
518 |
for new_text in streamer:
|
519 |
try:
|
520 |
+
if new_text is None:
|
521 |
+
new_text = ""
|
522 |
+
history[-1]['content'] += new_text
|
523 |
+
NUM_TOKENS += 1
|
524 |
yield history
|
525 |
except Exception as e:
|
526 |
print(f"Error during text generation: {e}")
|
|
|
546 |
NUM_TOKENS=0
|
547 |
print('-'*4+'Start Generation'+'-'*4)
|
548 |
|
549 |
+
output = model_object(
|
550 |
full_prompt, **vars(gen_config))
|
551 |
|
552 |
+
history.append({"metadata":None, "options":None, "role": "assistant", "content": ""})
|
553 |
+
|
554 |
for out in output:
|
555 |
|
556 |
if "choices" in out and len(out["choices"]) > 0 and "text" in out["choices"][0]:
|
557 |
+
history[-1]['content'] += out["choices"][0]["text"]
|
558 |
NUM_TOKENS+=1
|
559 |
yield history
|
560 |
else:
|
|
|
565 |
print('-'*4+'End Generation'+'-'*4)
|
566 |
print(f'Num of generated tokens: {NUM_TOKENS}')
|
567 |
print(f'Time for complete generation: {time_generate}s')
|
568 |
+
print(f'Tokens per second: {NUM_TOKENS/time_generate}')
|
569 |
print(f'Time per token: {(time_generate/NUM_TOKENS)*1000}ms')
|
570 |
|
571 |
+
elif "claude" in model_type:
|
572 |
system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
|
573 |
|
574 |
+
print("full_prompt:", full_prompt)
|
575 |
+
|
576 |
+
if isinstance(full_prompt, str):
|
577 |
+
full_prompt = [full_prompt]
|
578 |
+
|
579 |
+
model = model_type
|
580 |
+
config = {}
|
581 |
+
|
582 |
+
responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
|
583 |
+
|
584 |
+
if isinstance(responses[-1], ResponseObject):
|
585 |
+
response_texts = [resp.text for resp in responses]
|
586 |
+
elif "choices" in responses[-1]:
|
587 |
+
response_texts = [resp["choices"][0]['text'] for resp in responses]
|
588 |
+
else:
|
589 |
+
response_texts = [resp.text for resp in responses]
|
590 |
+
|
591 |
+
latest_response_text = response_texts[-1]
|
592 |
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
# Update the conversation history with the new prompt and response
|
594 |
+
clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text) # Replace newlines, tabs, and carriage returns with a space
|
595 |
+
clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip() # Remove all non-ASCII printable characters
|
596 |
|
597 |
+
history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})
|
|
|
598 |
|
599 |
+
for char in clean_response_text:
|
600 |
+
time.sleep(0.005)
|
601 |
+
history[-1]['content'] += char
|
602 |
+
yield history
|
603 |
+
|
604 |
+
elif "gemini" in model_type:
|
605 |
+
print("Using Gemini model:", model_type)
|
606 |
+
print("full_prompt:", full_prompt)
|
607 |
+
|
608 |
+
if isinstance(full_prompt, str):
|
609 |
+
full_prompt = [full_prompt]
|
610 |
+
|
611 |
+
system_prompt = "You are answering questions from the user based on source material. Respond with short, factually correct answers."
|
612 |
+
|
613 |
+
model, config = construct_gemini_generative_model(GEMINI_API_KEY, temperature, model_type, system_prompt, max_tokens)
|
614 |
+
|
615 |
+
responses, summary_conversation_history, whole_summary_conversation, whole_conversation_metadata = process_requests(full_prompt, system_prompt, conversation_history=[], whole_conversation=[], whole_conversation_metadata=[], model=model, config = config, model_choice = model_type, temperature = temperature)
|
616 |
+
|
617 |
+
if isinstance(responses[-1], ResponseObject):
|
618 |
+
response_texts = [resp.text for resp in responses]
|
619 |
+
elif "choices" in responses[-1]:
|
620 |
+
response_texts = [resp["choices"][0]['text'] for resp in responses]
|
621 |
+
else:
|
622 |
+
response_texts = [resp.text for resp in responses]
|
623 |
+
|
624 |
+
latest_response_text = response_texts[-1]
|
625 |
+
|
626 |
+
clean_text = re.sub(r'[\n\t\r]', ' ', latest_response_text) # Replace newlines, tabs, and carriage returns with a space
|
627 |
+
clean_response_text = re.sub(r'[^\x20-\x7E]', '', clean_text).strip() # Remove all non-ASCII printable characters
|
628 |
+
|
629 |
+
history.append({"metadata":None, "options":None, "role": "assistant", "content": ''})
|
630 |
+
|
631 |
+
for char in clean_response_text:
|
632 |
+
time.sleep(0.005)
|
633 |
+
history[-1]['content'] += char
|
634 |
+
yield history
|
635 |
+
|
636 |
+
print("history at end of function:", history)
|
637 |
|
638 |
# Chat helper functions
|
639 |
|
|
|
712 |
|
713 |
docs = vectorstore.similarity_search_with_score(new_question_kworded, k=k_val)
|
714 |
|
|
|
|
|
|
|
715 |
# Keep only documents with a certain score
|
716 |
docs_len = [len(x[0].page_content) for x in docs]
|
717 |
docs_scores = [x[1] for x in docs]
|
|
|
808 |
# 3rd level check on retrieved docs with SVM retriever
|
809 |
# Check the type of the embeddings object
|
810 |
embeddings_type = type(embeddings)
|
|
|
|
|
811 |
|
|
|
812 |
|
|
|
813 |
#hf_embeddings = HuggingFaceEmbeddings(**embeddings)
|
814 |
hf_embeddings = embeddings
|
815 |
|
|
|
859 |
# Make df of best options
|
860 |
doc_df = create_doc_df(docs_keep_out)
|
861 |
|
|
|
|
|
|
|
|
|
862 |
return docs_keep_as_doc, doc_df, docs_keep_out
|
863 |
|
864 |
def get_expanded_passages(vectorstore, docs, width):
|
|
|
948 |
|
949 |
return expanded_docs, doc_df
|
950 |
|
951 |
+
def highlight_found_text(chat_history: list[dict], source_texts: list[dict], hlt_chunk_size:int=hlt_chunk_size, hlt_strat:List=hlt_strat, hlt_overlap:int=hlt_overlap) -> str:
|
952 |
"""
|
953 |
+
Highlights occurrences of chat_history within source_texts.
|
954 |
|
955 |
Parameters:
|
956 |
+
- chat_history (str): The text to be searched for within source_texts.
|
957 |
+
- source_texts (str): The text within which chat_history occurrences will be highlighted.
|
958 |
|
959 |
Returns:
|
960 |
+
- str: A string with occurrences of chat_history highlighted.
|
961 |
|
962 |
Example:
|
963 |
>>> highlight_found_text("world", "Hello, world! This is a test. Another world awaits.")
|
|
|
971 |
return text[i][0].replace(" ", " ").strip()
|
972 |
else:
|
973 |
return ""
|
974 |
+
|
975 |
+
print("chat_history:", chat_history)
|
976 |
+
|
977 |
+
response_text = next(
|
978 |
+
(entry['content'] for entry in reversed(chat_history) if entry.get('role') == 'assistant'),
|
979 |
+
"")
|
980 |
|
981 |
+
print("response_text:", response_text)
|
982 |
+
|
983 |
+
source_texts = extract_text_from_input(source_texts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
984 |
|
985 |
+
print("source_texts:", source_texts)
|
986 |
|
987 |
text_splitter = RecursiveCharacterTextSplitter(
|
988 |
chunk_size=hlt_chunk_size,
|
989 |
separators=hlt_strat,
|
990 |
chunk_overlap=hlt_overlap,
|
991 |
)
|
992 |
+
sections = text_splitter.split_text(response_text)
|
993 |
|
994 |
found_positions = {}
|
995 |
for x in sections:
|
996 |
text_start_pos = 0
|
997 |
while text_start_pos != -1:
|
998 |
+
text_start_pos = source_texts.find(x, text_start_pos)
|
999 |
if text_start_pos != -1:
|
1000 |
found_positions[text_start_pos] = text_start_pos + len(x)
|
1001 |
text_start_pos += 1
|
|
|
1018 |
prev_end = 0
|
1019 |
for start, end in combined_positions:
|
1020 |
if end-start > 15: # Only combine if there is a significant amount of matched text. Avoids picking up single words like 'and' etc.
|
1021 |
+
pos_tokens.append(source_texts[prev_end:start])
|
1022 |
+
pos_tokens.append('<mark style="color:black;">' + source_texts[start:end] + '</mark>')
|
1023 |
prev_end = end
|
1024 |
+
pos_tokens.append(source_texts[prev_end:])
|
1025 |
+
|
1026 |
+
out_pos_tokens = "".join(pos_tokens)
|
1027 |
|
1028 |
+
print("out_pos_tokens:", out_pos_tokens)
|
1029 |
+
|
1030 |
+
return out_pos_tokens
|
1031 |
|
1032 |
|
1033 |
# # Chat history functions
|
1034 |
|
1035 |
def clear_chat(chat_history_state, sources, chat_message, current_topic):
|
1036 |
+
chat_history_state = None
|
1037 |
sources = ''
|
1038 |
+
chat_message = None
|
1039 |
current_topic = ''
|
1040 |
|
1041 |
return chat_history_state, sources, chat_message, current_topic
|
|
|
1126 |
for word in tokens_without_sw:
|
1127 |
if word not in ordered_tokens:
|
1128 |
ordered_tokens.add(word)
|
1129 |
+
result.append(word)
|
|
|
1130 |
|
1131 |
|
1132 |
new_question_keywords = ' '.join(result)
|
|
|
1135 |
def remove_q_ner_extractor(question):
|
1136 |
|
1137 |
predict_out = ner_model.predict(question)
|
|
|
|
|
|
|
1138 |
predict_tokens = [' '.join(v for k, v in d.items() if k == 'span') for d in predict_out]
|
1139 |
|
1140 |
# Remove duplicate words while preserving order
|
|
|
1186 |
return keywords_list
|
1187 |
|
1188 |
# Gradio functions
|
1189 |
+
def turn_off_interactivity():
|
1190 |
+
return gr.Textbox(interactive=False), gr.Button(interactive=False)
|
1191 |
|
1192 |
def restore_interactivity():
|
1193 |
+
return gr.Textbox(interactive=True), gr.Button(interactive=True)
|
1194 |
|
1195 |
def update_message(dropdown_value):
|
1196 |
return gr.Textbox(value=dropdown_value)
|
chatfuncs/config.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import socket
|
4 |
+
import logging
|
5 |
+
from datetime import datetime
|
6 |
+
from dotenv import load_dotenv
|
7 |
+
|
8 |
+
today_rev = datetime.now().strftime("%Y%m%d")
|
9 |
+
HOST_NAME = socket.gethostname()
|
10 |
+
|
11 |
+
# Set or retrieve configuration variables for the redaction app
|
12 |
+
|
13 |
+
def get_or_create_env_var(var_name:str, default_value:str, print_val:bool=False):
|
14 |
+
'''
|
15 |
+
Get an environmental variable, and set it to a default value if it doesn't exist
|
16 |
+
'''
|
17 |
+
# Get the environment variable if it exists
|
18 |
+
value = os.environ.get(var_name)
|
19 |
+
|
20 |
+
# If it doesn't exist, set the environment variable to the default value
|
21 |
+
if value is None:
|
22 |
+
os.environ[var_name] = default_value
|
23 |
+
value = default_value
|
24 |
+
|
25 |
+
if print_val == True:
|
26 |
+
print(f'The value of {var_name} is {value}')
|
27 |
+
|
28 |
+
return value
|
29 |
+
|
30 |
+
def ensure_folder_exists(output_folder:str):
|
31 |
+
"""Checks if the specified folder exists, creates it if not."""
|
32 |
+
|
33 |
+
if not os.path.exists(output_folder):
|
34 |
+
# Create the folder if it doesn't exist
|
35 |
+
os.makedirs(output_folder, exist_ok=True)
|
36 |
+
print(f"Created the {output_folder} folder.")
|
37 |
+
else:
|
38 |
+
print(f"The {output_folder} folder already exists.")
|
39 |
+
|
40 |
+
def add_folder_to_path(folder_path: str):
|
41 |
+
'''
|
42 |
+
Check if a folder exists on your system. If so, get the absolute path and then add it to the system Path variable if it doesn't already exist. Function is only relevant for locally-created executable files based on this app (when using pyinstaller it creates a _internal folder that contains tesseract and poppler. These need to be added to the system path to enable the app to run)
|
43 |
+
'''
|
44 |
+
|
45 |
+
if os.path.exists(folder_path) and os.path.isdir(folder_path):
|
46 |
+
print(folder_path, "folder exists.")
|
47 |
+
|
48 |
+
# Resolve relative path to absolute path
|
49 |
+
absolute_path = os.path.abspath(folder_path)
|
50 |
+
|
51 |
+
current_path = os.environ['PATH']
|
52 |
+
if absolute_path not in current_path.split(os.pathsep):
|
53 |
+
full_path_extension = absolute_path + os.pathsep + current_path
|
54 |
+
os.environ['PATH'] = full_path_extension
|
55 |
+
#print(f"Updated PATH with: ", full_path_extension)
|
56 |
+
else:
|
57 |
+
print(f"Directory {folder_path} already exists in PATH.")
|
58 |
+
else:
|
59 |
+
print(f"Folder not found at {folder_path} - not added to PATH")
|
60 |
+
|
61 |
+
ensure_folder_exists("config/")
|
62 |
+
|
63 |
+
# If you have an aws_config env file in the config folder, you can load in app variables this way, e.g. 'config/app_config.env'
|
64 |
+
APP_CONFIG_PATH = get_or_create_env_var('APP_CONFIG_PATH', 'config/app_config.env') # e.g. config/app_config.env
|
65 |
+
|
66 |
+
if APP_CONFIG_PATH:
|
67 |
+
if os.path.exists(APP_CONFIG_PATH):
|
68 |
+
print(f"Loading app variables from config file {APP_CONFIG_PATH}")
|
69 |
+
load_dotenv(APP_CONFIG_PATH)
|
70 |
+
else: print("App config file not found at location:", APP_CONFIG_PATH)
|
71 |
+
|
72 |
+
# Report logging to console?
|
73 |
+
LOGGING = get_or_create_env_var('LOGGING', 'False')
|
74 |
+
|
75 |
+
if LOGGING == 'True':
|
76 |
+
# Configure logging
|
77 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
78 |
+
|
79 |
+
###
|
80 |
+
# AWS CONFIG
|
81 |
+
###
|
82 |
+
|
83 |
+
# If you have an aws_config env file in the config folder, you can load in AWS keys this way, e.g. 'env/aws_config.env'
|
84 |
+
AWS_CONFIG_PATH = get_or_create_env_var('AWS_CONFIG_PATH', '') # e.g. config/aws_config.env
|
85 |
+
|
86 |
+
if AWS_CONFIG_PATH:
|
87 |
+
if os.path.exists(AWS_CONFIG_PATH):
|
88 |
+
print(f"Loading AWS variables from config file {AWS_CONFIG_PATH}")
|
89 |
+
load_dotenv(AWS_CONFIG_PATH)
|
90 |
+
else: print("AWS config file not found at location:", AWS_CONFIG_PATH)
|
91 |
+
|
92 |
+
RUN_AWS_FUNCTIONS = get_or_create_env_var("RUN_AWS_FUNCTIONS", "0")
|
93 |
+
|
94 |
+
AWS_REGION = get_or_create_env_var('AWS_REGION', '')
|
95 |
+
|
96 |
+
AWS_DEFAULT_REGION = get_or_create_env_var('AWS_DEFAULT_REGION', '')
|
97 |
+
|
98 |
+
AWS_CLIENT_ID = get_or_create_env_var('AWS_CLIENT_ID', '')
|
99 |
+
|
100 |
+
AWS_CLIENT_SECRET = get_or_create_env_var('AWS_CLIENT_SECRET', '')
|
101 |
+
|
102 |
+
AWS_USER_POOL_ID = get_or_create_env_var('AWS_USER_POOL_ID', '')
|
103 |
+
|
104 |
+
AWS_ACCESS_KEY = get_or_create_env_var('AWS_ACCESS_KEY', '')
|
105 |
+
if AWS_ACCESS_KEY: print(f'AWS_ACCESS_KEY found in environment variables')
|
106 |
+
|
107 |
+
AWS_SECRET_KEY = get_or_create_env_var('AWS_SECRET_KEY', '')
|
108 |
+
if AWS_SECRET_KEY: print(f'AWS_SECRET_KEY found in environment variables')
|
109 |
+
|
110 |
+
DOCUMENT_REDACTION_BUCKET = get_or_create_env_var('DOCUMENT_REDACTION_BUCKET', '')
|
111 |
+
|
112 |
+
# Custom headers e.g. if routing traffic through Cloudfront
|
113 |
+
# Retrieving or setting CUSTOM_HEADER
|
114 |
+
CUSTOM_HEADER = get_or_create_env_var('CUSTOM_HEADER', '')
|
115 |
+
#if CUSTOM_HEADER: print(f'CUSTOM_HEADER found')
|
116 |
+
|
117 |
+
# Retrieving or setting CUSTOM_HEADER_VALUE
|
118 |
+
CUSTOM_HEADER_VALUE = get_or_create_env_var('CUSTOM_HEADER_VALUE', '')
|
119 |
+
#if CUSTOM_HEADER_VALUE: print(f'CUSTOM_HEADER_VALUE found')
|
120 |
+
|
121 |
+
###
|
122 |
+
# File I/O config
|
123 |
+
###
|
124 |
+
SESSION_OUTPUT_FOLDER = get_or_create_env_var('SESSION_OUTPUT_FOLDER', 'False') # i.e. do you want your input and output folders saved within a subfolder based on session hash value within output/input folders
|
125 |
+
|
126 |
+
OUTPUT_FOLDER = get_or_create_env_var('GRADIO_OUTPUT_FOLDER', 'output/') # 'output/'
|
127 |
+
INPUT_FOLDER = get_or_create_env_var('GRADIO_INPUT_FOLDER', 'input/') # 'input/'
|
128 |
+
|
129 |
+
ensure_folder_exists(OUTPUT_FOLDER)
|
130 |
+
ensure_folder_exists(INPUT_FOLDER)
|
131 |
+
|
132 |
+
# Allow for files to be saved in a temporary folder for increased security in some instances
|
133 |
+
if OUTPUT_FOLDER == "TEMP" or INPUT_FOLDER == "TEMP":
|
134 |
+
# Create a temporary directory
|
135 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
136 |
+
print(f'Temporary directory created at: {temp_dir}')
|
137 |
+
|
138 |
+
if OUTPUT_FOLDER == "TEMP": OUTPUT_FOLDER = temp_dir + "/"
|
139 |
+
if INPUT_FOLDER == "TEMP": INPUT_FOLDER = temp_dir + "/"
|
140 |
+
|
141 |
+
# By default, logs are put into a subfolder of today's date and the host name of the instance running the app. This is to avoid at all possible the possibility of log files from one instance overwriting the logs of another instance on S3. If running the app on one system always, or just locally, it is not necessary to make the log folders so specific.
|
142 |
+
# Another way to address this issue would be to write logs to another type of storage, e.g. database such as dynamodb. I may look into this in future.
|
143 |
+
|
144 |
+
USE_LOG_SUBFOLDERS = get_or_create_env_var('USE_LOG_SUBFOLDERS', 'True')
|
145 |
+
|
146 |
+
if USE_LOG_SUBFOLDERS == "True":
|
147 |
+
day_log_subfolder = today_rev + '/'
|
148 |
+
host_name_subfolder = HOST_NAME + '/'
|
149 |
+
full_log_subfolder = day_log_subfolder + host_name_subfolder
|
150 |
+
else:
|
151 |
+
full_log_subfolder = ""
|
152 |
+
|
153 |
+
FEEDBACK_LOGS_FOLDER = get_or_create_env_var('FEEDBACK_LOGS_FOLDER', 'feedback/' + full_log_subfolder)
|
154 |
+
ACCESS_LOGS_FOLDER = get_or_create_env_var('ACCESS_LOGS_FOLDER', 'logs/' + full_log_subfolder)
|
155 |
+
USAGE_LOGS_FOLDER = get_or_create_env_var('USAGE_LOGS_FOLDER', 'usage/' + full_log_subfolder)
|
156 |
+
|
157 |
+
ensure_folder_exists(FEEDBACK_LOGS_FOLDER)
|
158 |
+
ensure_folder_exists(ACCESS_LOGS_FOLDER)
|
159 |
+
ensure_folder_exists(USAGE_LOGS_FOLDER)
|
160 |
+
|
161 |
+
# Should the redacted file name be included in the logs? In some instances, the names of the files themselves could be sensitive, and should not be disclosed beyond the app. So, by default this is false.
|
162 |
+
DISPLAY_FILE_NAMES_IN_LOGS = get_or_create_env_var('DISPLAY_FILE_NAMES_IN_LOGS', 'False')
|
163 |
+
|
164 |
+
###
|
165 |
+
# RUN CONFIG
|
166 |
+
GEMINI_API_KEY = get_or_create_env_var('GEMINI_API_KEY', '')
|
167 |
+
|
168 |
+
|
169 |
+
# Number of pages to loop through before breaking the function and restarting from the last finished page (not currently activated).
|
170 |
+
PAGE_BREAK_VALUE = get_or_create_env_var('PAGE_BREAK_VALUE', '99999')
|
171 |
+
|
172 |
+
MAX_TIME_VALUE = get_or_create_env_var('MAX_TIME_VALUE', '999999')
|
173 |
+
|
174 |
+
###
|
175 |
+
# APP RUN CONFIG
|
176 |
+
###
|
177 |
+
|
178 |
+
# Get some environment variables and Launch the Gradio app
|
179 |
+
COGNITO_AUTH = get_or_create_env_var('COGNITO_AUTH', '0')
|
180 |
+
|
181 |
+
RUN_DIRECT_MODE = get_or_create_env_var('RUN_DIRECT_MODE', '0')
|
182 |
+
|
183 |
+
MAX_QUEUE_SIZE = int(get_or_create_env_var('MAX_QUEUE_SIZE', '5'))
|
184 |
+
|
185 |
+
MAX_FILE_SIZE = get_or_create_env_var('MAX_FILE_SIZE', '250mb')
|
186 |
+
|
187 |
+
GRADIO_SERVER_PORT = int(get_or_create_env_var('GRADIO_SERVER_PORT', '7860'))
|
188 |
+
|
189 |
+
ROOT_PATH = get_or_create_env_var('ROOT_PATH', '')
|
190 |
+
|
191 |
+
DEFAULT_CONCURRENCY_LIMIT = get_or_create_env_var('DEFAULT_CONCURRENCY_LIMIT', '3')
|
192 |
+
|
193 |
+
GET_DEFAULT_ALLOW_LIST = get_or_create_env_var('GET_DEFAULT_ALLOW_LIST', 'False')
|
194 |
+
|
195 |
+
ALLOW_LIST_PATH = get_or_create_env_var('ALLOW_LIST_PATH', '') # config/default_allow_list.csv
|
196 |
+
|
197 |
+
S3_ALLOW_LIST_PATH = get_or_create_env_var('S3_ALLOW_LIST_PATH', '') # default_allow_list.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
|
198 |
+
|
199 |
+
if ALLOW_LIST_PATH: OUTPUT_ALLOW_LIST_PATH = ALLOW_LIST_PATH
|
200 |
+
else: OUTPUT_ALLOW_LIST_PATH = 'config/default_allow_list.csv'
|
201 |
+
|
202 |
+
SHOW_COSTS = get_or_create_env_var('SHOW_COSTS', 'False')
|
203 |
+
|
204 |
+
GET_COST_CODES = get_or_create_env_var('GET_COST_CODES', 'False')
|
205 |
+
|
206 |
+
DEFAULT_COST_CODE = get_or_create_env_var('DEFAULT_COST_CODE', '')
|
207 |
+
|
208 |
+
COST_CODES_PATH = get_or_create_env_var('COST_CODES_PATH', '') # 'config/COST_CENTRES.csv' # file should be a csv file with a single table in it that has two columns with a header. First column should contain cost codes, second column should contain a name or description for the cost code
|
209 |
+
|
210 |
+
S3_COST_CODES_PATH = get_or_create_env_var('S3_COST_CODES_PATH', '') # COST_CENTRES.csv # This is a path within the DOCUMENT_REDACTION_BUCKET
|
211 |
+
|
212 |
+
if COST_CODES_PATH: OUTPUT_COST_CODES_PATH = COST_CODES_PATH
|
213 |
+
else: OUTPUT_COST_CODES_PATH = 'config/COST_CENTRES.csv'
|
214 |
+
|
215 |
+
ENFORCE_COST_CODES = get_or_create_env_var('ENFORCE_COST_CODES', 'False') # If you have cost codes listed, is it compulsory to choose one before redacting?
|
216 |
+
|
217 |
+
if ENFORCE_COST_CODES == 'True': GET_COST_CODES = 'True'
|
chatfuncs/model_load.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# Currently set gpu_layers to 0 even with cuda due to persistent bugs in implementation with cuda
|
4 |
+
if torch.cuda.is_available():
|
5 |
+
torch_device = "cuda"
|
6 |
+
gpu_layers = 100
|
7 |
+
else:
|
8 |
+
torch_device = "cpu"
|
9 |
+
gpu_layers = 0
|
10 |
+
|
11 |
+
print("Running on device:", torch_device)
|
12 |
+
threads = 8 #torch.get_num_threads()
|
13 |
+
print("CPU threads:", threads)
|
14 |
+
|
15 |
+
# Qwen 2 0.5B (small, fast) Model parameters
|
16 |
+
temperature: float = 0.1
|
17 |
+
top_k: int = 3
|
18 |
+
top_p: float = 1
|
19 |
+
repetition_penalty: float = 1.15
|
20 |
+
flan_alpaca_repetition_penalty: float = 1.3
|
21 |
+
last_n_tokens: int = 64
|
22 |
+
max_new_tokens: int = 1024
|
23 |
+
seed: int = 42
|
24 |
+
reset: bool = False
|
25 |
+
stream: bool = True
|
26 |
+
threads: int = threads
|
27 |
+
batch_size:int = 256
|
28 |
+
context_length:int = 2048
|
29 |
+
sample = True
|
30 |
+
|
31 |
+
# Bedrock parameters
|
32 |
+
max_tokens = 4096
|
33 |
+
|
34 |
+
|
35 |
+
class CtransInitConfig_gpu:
|
36 |
+
def __init__(self,
|
37 |
+
last_n_tokens=last_n_tokens,
|
38 |
+
seed=seed,
|
39 |
+
n_threads=threads,
|
40 |
+
n_batch=batch_size,
|
41 |
+
n_ctx=max_tokens,
|
42 |
+
n_gpu_layers=gpu_layers):
|
43 |
+
|
44 |
+
self.last_n_tokens = last_n_tokens
|
45 |
+
self.seed = seed
|
46 |
+
self.n_threads = n_threads
|
47 |
+
self.n_batch = n_batch
|
48 |
+
self.n_ctx = n_ctx
|
49 |
+
self.n_gpu_layers = n_gpu_layers
|
50 |
+
# self.stop: list[str] = field(default_factory=lambda: [stop_string])
|
51 |
+
|
52 |
+
def update_gpu(self, new_value):
|
53 |
+
self.n_gpu_layers = new_value
|
54 |
+
|
55 |
+
class CtransInitConfig_cpu(CtransInitConfig_gpu):
|
56 |
+
def __init__(self):
|
57 |
+
super().__init__()
|
58 |
+
self.n_gpu_layers = 0
|
59 |
+
|
60 |
+
gpu_config = CtransInitConfig_gpu()
|
61 |
+
cpu_config = CtransInitConfig_cpu()
|
62 |
+
|
63 |
+
|
64 |
+
class CtransGenGenerationConfig:
|
65 |
+
def __init__(self, temperature=temperature,
|
66 |
+
top_k=top_k,
|
67 |
+
top_p=top_p,
|
68 |
+
repeat_penalty=repetition_penalty,
|
69 |
+
seed=seed,
|
70 |
+
stream=stream,
|
71 |
+
max_tokens=max_new_tokens
|
72 |
+
):
|
73 |
+
self.temperature = temperature
|
74 |
+
self.top_k = top_k
|
75 |
+
self.top_p = top_p
|
76 |
+
self.repeat_penalty = repeat_penalty
|
77 |
+
self.seed = seed
|
78 |
+
self.max_tokens=max_tokens
|
79 |
+
self.stream = stream
|
80 |
+
|
81 |
+
def update_temp(self, new_value):
|
82 |
+
self.temperature = new_value
|
chatfuncs/prompts.py
CHANGED
@@ -23,8 +23,7 @@ QUESTION - {question}
|
|
23 |
"""
|
24 |
|
25 |
|
26 |
-
instruction_prompt_template_orca = """
|
27 |
-
### System:
|
28 |
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
29 |
### User:
|
30 |
Answer the QUESTION with a short response using information from the following CONTENT.
|
@@ -33,8 +32,7 @@ CONTENT: {summaries}
|
|
33 |
|
34 |
### Response:"""
|
35 |
|
36 |
-
instruction_prompt_template_orca_quote = """
|
37 |
-
### System:
|
38 |
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
39 |
### User:
|
40 |
Quote text from the CONTENT to answer the QUESTION below.
|
|
|
23 |
"""
|
24 |
|
25 |
|
26 |
+
instruction_prompt_template_orca = """### System:
|
|
|
27 |
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
28 |
### User:
|
29 |
Answer the QUESTION with a short response using information from the following CONTENT.
|
|
|
32 |
|
33 |
### Response:"""
|
34 |
|
35 |
+
instruction_prompt_template_orca_quote = """### System:
|
|
|
36 |
You are an AI assistant that follows instruction extremely well. Help as much as you can.
|
37 |
### User:
|
38 |
Quote text from the CONTENT to answer the QUESTION below.
|
requirements.txt
CHANGED
@@ -22,3 +22,4 @@ PyStemmer==2.2.0.3
|
|
22 |
scipy==1.15.2
|
23 |
numpy==1.26.4
|
24 |
boto3==1.38.0
|
|
|
|
22 |
scipy==1.15.2
|
23 |
numpy==1.26.4
|
24 |
boto3==1.38.0
|
25 |
+
python-dotenv==1.1.0
|
requirements_gpu.txt
CHANGED
@@ -20,4 +20,5 @@ bm25s==0.2.12
|
|
20 |
PyStemmer==2.2.0.3
|
21 |
scipy==1.15.2
|
22 |
numpy==1.26.4
|
23 |
-
boto3==1.38.0
|
|
|
|
20 |
PyStemmer==2.2.0.3
|
21 |
scipy==1.15.2
|
22 |
numpy==1.26.4
|
23 |
+
boto3==1.38.0
|
24 |
+
python-dotenv==1.1.0
|