added llm_qa_chain_with_memory.py
Files changed:

- .gitignore +1 -0
- Makefile +4 -1
- app.py +9 -4
- app_modules/init.py +10 -1
- app_modules/llm_qa_chain_with_memory.py +32 -0
- app_modules/utils.py +4 -0
- qa_chain_with_memory_test.py +104 -0
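Taken together, the changes add a conversation-memory variant of the QA chain plus a test/chat driver for it. A minimal end-to-end sketch of the intended flow, assuming `app_init()` returns `(llm_loader, qa_chain)` and that `call_chain()` tolerates `None` in place of the streaming callback handler (the test script below always passes a handler):

```python
import os

# The toggle must be set before app_modules.init is imported, because the
# QAChain implementation is chosen at module import time (see init.py below).
os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"

from app_modules.init import app_init
from app_modules.utils import print_llm_response

llm_loader, qa_chain = app_init()

chat_history = []
result = qa_chain.call_chain(
    {"question": "What is PCI DSS?", "chat_history": chat_history},
    None,  # assumption: no streaming callback handler needed here
    None,
    True,
)
print_llm_response(result)
```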
.gitignore CHANGED

```diff
@@ -1,4 +1,5 @@
 *.out
+*.log
 pdfs/
 .vscode/
 
```
Makefile CHANGED

```diff
@@ -5,8 +5,11 @@ start:
 test:
 	python qa_chain_test.py
 
+long-test:
+	python qa_chain_with_memory_test.py 100
+
 chat:
-	python
+	python qa_chain_with_memory_test.py chat
 
 ingest:
 	python ingest.py
```
app.py CHANGED

```diff
@@ -8,6 +8,8 @@ from timeit import default_timer as timer
 
 import gradio as gr
 
+os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"
+
 from app_modules.init import app_init
 from app_modules.utils import print_llm_response
 
@@ -29,10 +31,13 @@ href = (
 )
 
 title = "Chat with PCI DSS v4"
-
-
-
-
+
+questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
+
+# Open the file for reading
+with open(questions_file_path, "r") as file:
+    examples = file.readlines()
+    examples = [example.strip() for example in examples]
 
 description = f"""\
 <div align="left">
```
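The second hunk loads example questions from the file pointed to by `QUESTIONS_FILE_PATH`; the diff does not show how `examples` is consumed further down in app.py. A hedged sketch of one common way such a list feeds a Gradio UI (the component layout here is illustrative, not taken from the commit):

```python
import gradio as gr

# Placeholder list standing in for the questions read from QUESTIONS_FILE_PATH.
examples = ["What is PCI DSS?", "What are its key requirements?"]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Question")
    # Clicking an example copies it into the textbox; the wiring from the
    # textbox to the QA chain is omitted in this sketch.
    gr.Examples(examples=examples, inputs=msg)

# demo.launch()
```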
app_modules/init.py CHANGED

```diff
@@ -1,4 +1,5 @@
 """Main entrypoint for the app."""
+
 import os
 from timeit import default_timer as timer
 from typing import List, Optional
@@ -9,7 +10,6 @@ from langchain.vectorstores.chroma import Chroma
 from langchain.vectorstores.faiss import FAISS
 
 from app_modules.llm_loader import LLMLoader
-from app_modules.llm_qa_chain import QAChain
 from app_modules.utils import get_device_types, init_settings
 
 found_dotenv = find_dotenv(".env")
@@ -27,6 +27,15 @@ if os.environ.get("LANGCHAIN_DEBUG") == "true":
 
     langchain.debug = True
 
+if os.environ.get("USER_CONVERSATION_SUMMARY_BUFFER_MEMORY") == "true":
+    from app_modules.llm_qa_chain_with_memory import QAChain
+
+    print("using llm_qa_chain_with_memory")
+else:
+    from app_modules.llm_qa_chain import QAChain
+
+    print("using llm_qa_chain")
+
 
 def app_init():
     # https://github.com/huggingface/transformers/issues/17611
```
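Because the `QAChain` import is now resolved once at module level, flipping `USER_CONVERSATION_SUMMARY_BUFFER_MEMORY` after `app_modules.init` has already been imported has no effect. A quick sketch for checking which implementation was picked (assuming, as in the test script below, that `app_init()` returns `(llm_loader, qa_chain)`):

```python
import os

# Must be set before the first import of app_modules.init.
os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"

from app_modules.init import app_init

llm_loader, qa_chain = app_init()
# Expected to print "app_modules.llm_qa_chain_with_memory" when the toggle is on,
# and "app_modules.llm_qa_chain" otherwise.
print(type(qa_chain).__module__)
```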
app_modules/llm_qa_chain_with_memory.py ADDED

```python
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.memory import ConversationSummaryBufferMemory

from app_modules.llm_inference import LLMInference


class QAChain(LLMInference):
    def __init__(self, vectorstore, llm_loader):
        super().__init__(llm_loader)
        self.vectorstore = vectorstore

    def create_chain(self) -> Chain:
        memory = ConversationSummaryBufferMemory(
            llm=self.llm_loader.llm,
            output_key="answer",
            memory_key="chat_history",
            max_token_limit=1024,
            return_messages=True,
        )
        qa = ConversationalRetrievalChain.from_llm(
            self.llm_loader.llm,
            memory=memory,
            chain_type="stuff",
            retriever=self.vectorstore.as_retriever(
                search_kwargs=self.llm_loader.search_kwargs
            ),
            get_chat_history=lambda h: h,
            return_source_documents=True,
        )

        return qa
```
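A hedged sketch of how the chain built by `create_chain()` behaves across turns: `ConversationSummaryBufferMemory` keeps recent exchanges verbatim and summarizes older ones once `max_token_limit` is exceeded, so a follow-up question can lean on the previous answer. This assumes a `vectorstore` and `llm_loader` already initialized elsewhere (e.g. by `app_init()`); it is not part of the commit:

```python
chain = QAChain(vectorstore, llm_loader).create_chain()

# First turn: answered from the retrieved PCI DSS passages and stored in the
# chain's memory under the "chat_history" key.
first = chain({"question": "What is PCI DSS?"})
print(first["answer"])

# Second turn: because memory is attached, only the new question is passed in;
# the follow-up pronoun is resolved against the stored history.
second = chain({"question": "What are its key requirements?"})
print(second["answer"])

# return_source_documents=True adds the retrieved chunks to each result.
print(len(second["source_documents"]))
```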
app_modules/utils.py CHANGED

```diff
@@ -85,6 +85,10 @@ def print_llm_response(llm_response):
             source["page_content"] if "page_content" in source else source.page_content
         )
 
+    if "chat_history" in llm_response:
+        print("\nChat History:")
+        print(llm_response["chat_history"])
+
 
 def get_device_types():
     print("Running on: ", platform.platform())
```
qa_chain_with_memory_test.py ADDED

```python
import os
import sys
from timeit import default_timer as timer

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult

os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"

from app_modules.init import app_init
from app_modules.utils import print_llm_response

llm_loader, qa_chain = app_init()


class MyCustomHandler(BaseCallbackHandler):
    def __init__(self):
        self.reset()

    def reset(self):
        self.texts = []

    def get_standalone_question(self) -> str:
        return self.texts[0].strip() if len(self.texts) > 0 else None

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Run when chain ends running."""
        print("\n<on_llm_end>")
        # print(response)
        self.texts.append(response.generations[0][0].text)


num_of_test_runs = 1
chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
if len(sys.argv) > 1 and not chatting:
    num_of_test_runs = int(sys.argv[1])

questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") or "true"

custom_handler = MyCustomHandler()

# Chatbot loop
chat_history = []

# Open the file for reading
file = open(questions_file_path, "r")

# Read the contents of the file into a list of strings
questions = file.readlines()
for i in range(len(questions)):
    questions[i] = questions[i].strip()

if num_of_test_runs > 1:
    new_questions = []

    for i in range(num_of_test_runs):
        new_questions += questions

    questions = new_questions

# Close the file
file.close()

if __name__ == "__main__":
    questions.append("exit")

    chat_start = timer()

    while True:
        if chatting:
            query = input("Please enter your question: ")
        else:
            query = questions.pop(0)

        query = query.strip()
        if query.lower() == "exit":
            break

        print("\nQuestion: " + query)
        custom_handler.reset()

        start = timer()
        result = qa_chain.call_chain(
            {"question": query, "chat_history": chat_history},
            custom_handler,
            None,
            True,
        )
        end = timer()
        print(f"Completed in {end - start:.3f}s")

        if chat_history_enabled == "true":
            chat_history.append((query, result["answer"]))

        print_llm_response(result)

    chat_end = timer()
    total_time = chat_end - chat_start
    print(f"Total time used: {total_time:.3f} s")
    print(f"Number of tokens generated: {llm_loader.streamer.total_tokens}")
    print(
        f"Average generation speed: {llm_loader.streamer.total_tokens / total_time:.3f} tokens/s"
    )
```
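One note on `MyCustomHandler`: the script treats the first text captured by `on_llm_end` in a turn as the standalone question. This rests on the assumption that, when prior chat history exists, `ConversationalRetrievalChain` runs a question-rewriting LLM call before the answering call; on the very first turn there is no rewrite, so the first captured text is simply the answer. Under that assumption, the rewritten question could be inspected inside the loop right after `call_chain` returns:

```python
# Inside the while-loop above, after call_chain(...) has returned:
standalone = custom_handler.get_standalone_question()
if standalone:
    print("Standalone question:", standalone)
```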