dh-mc committed on
Commit 9e7327c · 1 Parent(s): 7190ef8

added llm_qa_chain_with_memory.py

.gitignore CHANGED
@@ -1,4 +1,5 @@
 *.out
+*.log
 pdfs/
 .vscode/
 
Makefile CHANGED
@@ -5,8 +5,11 @@ start:
 test:
 	python qa_chain_test.py
 
+long-test:
+	python qa_chain_with_memory_test.py 100
+
 chat:
-	python qa_chain_test.py chat
+	python qa_chain_with_memory_test.py chat
 
 ingest:
 	python ingest.py
app.py CHANGED
@@ -8,6 +8,8 @@ from timeit import default_timer as timer
 
 import gradio as gr
 
+os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"
+
 from app_modules.init import app_init
 from app_modules.utils import print_llm_response
 
@@ -29,10 +31,13 @@ href = (
 )
 
 title = "Chat with PCI DSS v4"
-examples = [
-    "What's PCI DSS?",
-    "Can you summarize the changes made from PCI DSS version 3.2.1 to version 4.0?",
-]
+
+questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
+
+# Open the file for reading
+with open(questions_file_path, "r") as file:
+    examples = file.readlines()
+    examples = [example.strip() for example in examples]
 
 description = f"""\
 <div align="left">
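
Note: the hard-coded examples list is replaced by questions read from the file at QUESTIONS_FILE_PATH, one per line. A minimal sketch of preparing such a file and pointing the app at it; the path below is illustrative (not part of this commit), and the two questions are the ones previously hard-coded:

    import os

    # Hypothetical path; the app only requires that QUESTIONS_FILE_PATH points to a readable file.
    questions_file = "./data/questions.txt"

    with open(questions_file, "w") as f:
        f.write("What's PCI DSS?\n")
        f.write(
            "Can you summarize the changes made from PCI DSS version 3.2.1 to version 4.0?\n"
        )

    os.environ["QUESTIONS_FILE_PATH"] = questions_file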
app_modules/init.py CHANGED
@@ -1,4 +1,5 @@
 """Main entrypoint for the app."""
+
 import os
 from timeit import default_timer as timer
 from typing import List, Optional
@@ -9,7 +10,6 @@ from langchain.vectorstores.chroma import Chroma
 from langchain.vectorstores.faiss import FAISS
 
 from app_modules.llm_loader import LLMLoader
-from app_modules.llm_qa_chain import QAChain
 from app_modules.utils import get_device_types, init_settings
 
 found_dotenv = find_dotenv(".env")
@@ -27,6 +27,15 @@ if os.environ.get("LANGCHAIN_DEBUG") == "true":
 
     langchain.debug = True
 
+if os.environ.get("USER_CONVERSATION_SUMMARY_BUFFER_MEMORY") == "true":
+    from app_modules.llm_qa_chain_with_memory import QAChain
+
+    print("using llm_qa_chain_with_memory")
+else:
+    from app_modules.llm_qa_chain import QAChain
+
+    print("using llm_qa_chain")
+
 
 def app_init():
     # https://github.com/huggingface/transformers/issues/17611
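
Note: the QAChain implementation is now chosen at import time, so USER_CONVERSATION_SUMMARY_BUFFER_MEMORY must be set before app_modules.init is imported (app.py and the new test script both do this). A minimal sketch, assuming the .env and vectorstore setup that app_init() already expects:

    import os

    # Must be set before the import below, because the import branch runs at module load time.
    os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"

    from app_modules.init import app_init

    llm_loader, qa_chain = app_init()  # qa_chain is the memory-backed QAChain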
app_modules/llm_qa_chain_with_memory.py ADDED
@@ -0,0 +1,32 @@
+from langchain.chains import ConversationalRetrievalChain
+from langchain.chains.base import Chain
+from langchain.memory import ConversationSummaryBufferMemory
+
+from app_modules.llm_inference import LLMInference
+
+
+class QAChain(LLMInference):
+    def __init__(self, vectorstore, llm_loader):
+        super().__init__(llm_loader)
+        self.vectorstore = vectorstore
+
+    def create_chain(self) -> Chain:
+        memory = ConversationSummaryBufferMemory(
+            llm=self.llm_loader.llm,
+            output_key="answer",
+            memory_key="chat_history",
+            max_token_limit=1024,
+            return_messages=True,
+        )
+        qa = ConversationalRetrievalChain.from_llm(
+            self.llm_loader.llm,
+            memory=memory,
+            chain_type="stuff",
+            retriever=self.vectorstore.as_retriever(
+                search_kwargs=self.llm_loader.search_kwargs
+            ),
+            get_chat_history=lambda h: h,
+            return_source_documents=True,
+        )
+
+        return qa
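
Note: create_chain() attaches a ConversationSummaryBufferMemory (older turns are summarized once the 1024-token limit is reached) to a ConversationalRetrievalChain over the existing vectorstore. A minimal usage sketch, assuming a vectorstore and llm_loader already initialized the way app_init() does; the question text is illustrative:

    from app_modules.llm_qa_chain_with_memory import QAChain

    qa_chain = QAChain(vectorstore, llm_loader)  # vectorstore / llm_loader assumed to exist
    chain = qa_chain.create_chain()

    # The attached memory supplies "chat_history", so only the question needs to be passed.
    result = chain({"question": "What's PCI DSS?"})
    print(result["answer"])
    print(len(result["source_documents"]))  # present because return_source_documents=True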
app_modules/utils.py CHANGED
@@ -85,6 +85,10 @@ def print_llm_response(llm_response):
             source["page_content"] if "page_content" in source else source.page_content
         )
 
+    if "chat_history" in llm_response:
+        print("\nChat History:")
+        print(llm_response["chat_history"])
+
 
 def get_device_types():
     print("Running on: ", platform.platform())
qa_chain_with_memory_test.py ADDED
@@ -0,0 +1,104 @@
+import os
+import sys
+from timeit import default_timer as timer
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.schema import LLMResult
+
+os.environ["USER_CONVERSATION_SUMMARY_BUFFER_MEMORY"] = "true"
+
+from app_modules.init import app_init
+from app_modules.utils import print_llm_response
+
+llm_loader, qa_chain = app_init()
+
+
+class MyCustomHandler(BaseCallbackHandler):
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.texts = []
+
+    def get_standalone_question(self) -> str:
+        return self.texts[0].strip() if len(self.texts) > 0 else None
+
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        """Run when chain ends running."""
+        print("\n<on_llm_end>")
+        # print(response)
+        self.texts.append(response.generations[0][0].text)
+
+
+num_of_test_runs = 1
+chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
+if len(sys.argv) > 1 and not chatting:
+    num_of_test_runs = int(sys.argv[1])
+
+questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
+chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") or "true"
+
+custom_handler = MyCustomHandler()
+
+# Chatbot loop
+chat_history = []
+
+# Open the file for reading
+file = open(questions_file_path, "r")
+
+# Read the contents of the file into a list of strings
+questions = file.readlines()
+for i in range(len(questions)):
+    questions[i] = questions[i].strip()
+
+if num_of_test_runs > 1:
+    new_questions = []
+
+    for i in range(num_of_test_runs):
+        new_questions += questions
+
+    questions = new_questions
+
+# Close the file
+file.close()
+
+if __name__ == "__main__":
+    questions.append("exit")
+
+    chat_start = timer()
+
+    while True:
+        if chatting:
+            query = input("Please enter your question: ")
+        else:
+            query = questions.pop(0)
+
+        query = query.strip()
+        if query.lower() == "exit":
+            break
+
+        print("\nQuestion: " + query)
+        custom_handler.reset()
+
+        start = timer()
+        result = qa_chain.call_chain(
+            {"question": query, "chat_history": chat_history},
+            custom_handler,
+            None,
+            True,
+        )
+        end = timer()
+        print(f"Completed in {end - start:.3f}s")
+
+        if chat_history_enabled == "true":
+            chat_history.append((query, result["answer"]))
+
+        print_llm_response(result)
+
+    chat_end = timer()
+    total_time = chat_end - chat_start
+    print(f"Total time used: {total_time:.3f} s")
+    print(f"Number of tokens generated: {llm_loader.streamer.total_tokens}")
+    print(
+        f"Average generation speed: {llm_loader.streamer.total_tokens / total_time:.3f} tokens/s"
+    )