Kadi-IAM committed
Commit 646f8c2 · 1 Parent(s): e607fab

Remove unused code

Files changed (6)
  1. app.py +5 -6
  2. embeddings.py +1 -1
  3. llms.py +8 -77
  4. preprocess_documents.py +59 -0
  5. models.py → ragchain.py +28 -201
  6. vectorestores.py +1 -1
app.py CHANGED
@@ -5,17 +5,16 @@ from pathlib import Path
 from dotenv import load_dotenv
 import pickle
 
-from llms import get_groq_chat
 
 import gradio as gr
 
 from huggingface_hub import login
+from langchain.vectorstores import FAISS
 
+from llms import get_groq_chat
 from documents import load_pdf_as_docs, load_xml_as_docs
-
 from vectorestores import get_faiss_vectorestore
 
-from langchain.vectorstores import FAISS
 
 # For debug
 # from langchain.globals import set_debug
@@ -100,7 +99,7 @@ llm = get_groq_chat(model_name="llama-3.1-70b-versatile")
 
 
 # # # Create conversation qa chain (Note: conversation is not supported yet)
-from models import RAGChain
+from ragchain import RAGChain
 
 rag_chain = RAGChain()
 lisa_qa_conversation = rag_chain.create(rerank_retriever, llm, add_citation=True)
@@ -213,8 +212,8 @@ def postprocess_citation(text, source_docs):
     # print(f"source ids by re: {source_ids}")
     # source_ids = re.findall(r"\[\[(.*?)\]\]", text)  # List[Char]
     aligned_source_ids = list(map(lambda x: int(x) - 1, source_ids))  # shift index-1
-    # print(f"souce ids generated by llm: {aligned_source_ids}")
-    # Filter fake souce ids as LLM might generate false source ids
+    # print(f"source ids generated by llm: {aligned_source_ids}")
+    # Filter fake source ids as LLM might generate false source ids
     candidate_source_ids = list(range(len(source_docs)))
     filtered_source_ids = set(
         [i for i in aligned_source_ids if i in candidate_source_ids]
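The hunk above only fixes "souce" to "source" in comments; the surrounding logic maps the citation markers the LLM emits back onto real source documents. A minimal standalone sketch of that filtering step, assuming a simple "[1]"-style citation format (the exact regex used in app.py is not shown in this diff):

import re

def filter_cited_ids(text, source_docs):
    # Hypothetical helper mirroring the hunk above; not part of the commit.
    source_ids = re.findall(r"\[(\d+)\]", text)  # assumed citation format, e.g. "[1]"
    aligned_source_ids = [int(x) - 1 for x in source_ids]  # shift to 0-based indices
    candidate_source_ids = range(len(source_docs))
    # Keep only ids that point at a real source document; drop hallucinated ones
    return sorted(set(i for i in aligned_source_ids if i in candidate_source_ids))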
embeddings.py CHANGED
@@ -20,7 +20,7 @@ def get_hf_embeddings(model_name=None):
 
 
 def get_jinaai_embeddings(model_name="jinaai/jina-embeddings-v2-base-en", device="auto"):
-    """Get jinnai embedding."""
+    """Get jinaai embedding."""
 
     # device: cpu or cuda
     if device == "auto":
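The context lines show that get_jinaai_embeddings resolves device="auto" itself ("# device: cpu or cuda"); the resolution logic is not part of this diff. A minimal sketch of the usual convention, assuming a torch-based check:

import torch

# Assumed convention for device="auto": prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"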
llms.py CHANGED
@@ -1,44 +1,17 @@
 # from langchain import HuggingFaceHub, LLMChain
-from langchain.chains import LLMChain
 from langchain.llms import HuggingFacePipeline
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     pipeline,
-    T5Tokenizer,
-    T5ForConditionalGeneration,
-    GPT2TokenizerFast,
 )
 from transformers import LlamaForCausalLM, AutoModelForCausalLM, LlamaTokenizer
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, PromptTemplate
 from langchain_groq import ChatGroq
 
 
-# model_path = "/mnt/localstorage/yinghan/llm/orca_mini_v3_13b"
-# model = LlamaForCausalLM.from_pretrained(model_path, device_map="auto")  # , load_in_8bit=True)
-# tokenizer = AutoTokenizer.from_pretrained(model_path)
 from langchain.chat_models import ChatOpenAI
-# from langchain_openai import ChatOpenAI
-# from langchain_openai import ChatOpenAI
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.vectorstores import Chroma
-from langchain.text_splitter import (
-    CharacterTextSplitter,
-    RecursiveCharacterTextSplitter,
-)
-from langchain.document_loaders import TextLoader, UnstructuredHTMLLoader, PyPDFLoader
-from langchain.chains.retrieval_qa.base import RetrievalQA
-from langchain.llms import HuggingFaceHub
-from dotenv import load_dotenv
 from langchain.llms import HuggingFaceTextGenInference
-from langchain.chains.question_answering import load_qa_chain
-from langchain.chains import ConversationalRetrievalChain
-from langchain.chains.conversation.memory import (
-    ConversationBufferMemory,
-    ConversationBufferWindowMemory,
-)
 
 
 def get_llm_hf_online(inference_api_url=""):
@@ -50,20 +23,12 @@ def get_llm_hf_online(inference_api_url=""):
     )
 
     llm = HuggingFaceTextGenInference(
-        # cache=None,  # Optional: whether to use a cache
         verbose=True,  # Provides detailed logs of operation
-        # callbacks=[StreamingStdOutCallbackHandler()],  # Handling streams
         max_new_tokens=1024,  # Maximum number of tokens that can be generated.
-        # top_k=2,  # Number of top-k tokens to consider during generation
         top_p=0.95,  # Threshold for controlling randomness in text generation process.
-        typical_p=0.95,
-        temperature=0.1,  # For choosing probable words.
-        # repetition_penalty=None,  # Repetition penalty during generation
-        # truncate=None,  # Truncate the input tokens to the given size
-        # stop_sequences=None,  # A list of stop sequences for generation
-        inference_server_url=inference_api_url,  # URL of the inference server
+        temperature=0.1,
+        inference_server_url=inference_api_url,
         timeout=10,  # Timeout for connection with the url
-        # streaming=True,  # Streaming the answer
     )
 
     return llm
@@ -72,12 +37,9 @@ def get_llm_hf_online(inference_api_url=""):
 def get_llm_hf_local(model_path):
     """Get local LLM."""
 
-    # model_path = "/mnt/localstorage/yinghan/llm/orca_mini_v3_13b"
-    # model_path = "/mnt/localstorage/yinghan/llm/zephyr-7b-beta"
-    model = LlamaForCausalLM.from_pretrained(  # or AutoModelForCausalLM. TODO: which is better? what's the difference?
+    model = LlamaForCausalLM.from_pretrained(
         model_path, device_map="auto"
-    )  # , load_in_8bit=True)
-    # model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")  # , load_in_8bit=True)  # which is better?
+    )
     tokenizer = AutoTokenizer.from_pretrained(model_path)
 
     # print('making a pipeline...')
@@ -86,8 +48,8 @@ def get_llm_hf_local(model_path):
         "text-generation",
         model=model,
        tokenizer=tokenizer,
-        max_new_tokens=1024,  # need better set
-        model_kwargs={"temperature": 0.1},  # need better set
+        max_new_tokens=1024,  # better setting?
+        model_kwargs={"temperature": 0.1},  # better setting?
     )
     llm = HuggingFacePipeline(pipeline=pipe)
 
@@ -95,51 +57,20 @@ def get_llm_hf_local(model_path):
 
 
 
-def get_llm_openai_chat(model_name, inference_server_url, langfuse_callback=None):
+def get_llm_openai_chat(model_name, inference_server_url):
     """Get openai-like LLM."""
 
-    # Some defaults
-    # chat_model_name = "openchat/openchat_3.5"
-    # inference_server_url = "http://localhost:8080/v1"
     llm = ChatOpenAI(
         model=model_name,
         openai_api_key="EMPTY",
         openai_api_base=inference_server_url,
         max_tokens=1024,  # better setting?
-        temperature=0,  # default 0.7, better setting?
-        # callbacks=[langfuse_callback],
+        temperature=0,
     )
 
-    # The following is not required for building a normal llm
-    # use the Ragas LangchainLLM wrapper to create a RagasLLM instance
-    # vllm = LangchainLLM(llm=chat)
-    # return vllm
     return llm
 
 
-def get_chat_vllm(model_name, inference_server_url, langfuse_callback=None):
-
-    # to fix
-    # Create vLLM Langchain instance
-
-    # Some defaults
-    # chat_model_name = "openchat/openchat_3.5"
-    # inference_server_url = "http://localhost:8080/v1"
-    chat = ChatOpenAI(
-        model=model_name,
-        openai_api_key="EMPTY",
-        openai_api_base=inference_server_url,
-        max_tokens=512,  # better setting?
-        temperature=0.1,  # default 0.7, better setting?
-        # callbacks=[langfuse_callback],
-    )
-
-    # The following is not required for building a normal llm
-    # use the Ragas LangchainLLM wrapper to create a RagasLLM instance
-    # vllm = LangchainLLM(llm=chat)
-    # return vllm
-    return chat
-
 def get_groq_chat(model_name="llama-3.1-70b-versatile"):
 
     llm = ChatGroq(temperature=0, model_name=model_name)
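Of the constructors that remain, get_groq_chat is the one app.py now uses. A minimal usage sketch; the GROQ_API_KEY environment variable and the example prompt are assumptions about the deployment, not part of this diff:

from langchain_groq import ChatGroq

# ChatGroq reads GROQ_API_KEY from the environment; the model name mirrors the default above.
llm = ChatGroq(temperature=0, model_name="llama-3.1-70b-versatile")
response = llm.invoke("What is a solid-state electrolyte?")  # returns an AIMessage
print(response.content)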
preprocess_documents.py ADDED
@@ -0,0 +1,59 @@
+"""
+Load and parse files (pdf) in the data/documents directory and save cached pkl files.
+"""
+
+import os
+import pickle
+
+from dotenv import load_dotenv
+
+
+from huggingface_hub import login
+
+from documents import load_pdf_as_docs, get_doc_chunks
+from embeddings import get_jinaai_embeddings
+
+
+# Load and set env variables
+load_dotenv()
+
+# Set huggingface api for downloading embedding model
+HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
+login(HUGGINGFACEHUB_API_TOKEN)
+
+
+def save_to_pickle(obj, filename):
+    with open(filename, "wb") as file:
+        pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)
+
+
+# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
+database_root = "./data/db"
+document_path = "./data/documents"
+
+# Parse pdf as "Documents" instances and save as "docs.pkl"
+docs = load_pdf_as_docs(document_path)
+save_to_pickle(docs, os.path.join(database_root, "docs.pkl"))
+
+# Get text chunks and save as "docs_chunks.pkl"
+document_chunks = get_doc_chunks(docs)
+save_to_pickle(document_chunks, os.path.join(database_root, "docs_chunks.pkl"))
+
+embeddings = get_jinaai_embeddings(device="auto")
+
+# Create and save vectorstore
+from vectorestores import get_faiss_vectorestore
+
+vectorstore = get_faiss_vectorestore(embeddings)
+
+# Create retrievers
+from retrievers import get_parent_doc_retriever
+
+# Get parent doc (small-to-big) retriever and save as "docstore.pkl"
+parent_doc_retriever = get_parent_doc_retriever(
+    docs,
+    vectorstore,
+    save_path_root=database_root,
+    save_vectorstore=True,
+    save_docstore=True,
+)
models.py → ragchain.py RENAMED
@@ -1,30 +1,39 @@
 from langchain.chains import LLMChain
 
-from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, PromptTemplate
+from langchain.prompts import (
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+    ChatPromptTemplate,
+    PromptTemplate,
+)
 
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.conversation.memory import (
-    ConversationBufferMemory,
     ConversationBufferWindowMemory,
 )
 
 
-from langchain.chains import RetrievalQAWithSourcesChain, StuffDocumentsChain
+from langchain.chains import StuffDocumentsChain
+
 
 def get_cite_combine_docs_chain(llm):
-
+
     # Ref: https://github.com/langchain-ai/langchain/issues/7239
     # Function to format each document with an index, source, and content.
     def format_document(doc, index, prompt):
         """Format a document into a string based on a prompt template."""
         # Create a dictionary with document content and metadata.
-        base_info = {"page_content": doc.page_content, "index": index, "source": doc.metadata["source"]}
-
+        base_info = {
+            "page_content": doc.page_content,
+            "index": index,
+            "source": doc.metadata["source"],
+        }
+
         # Check if any metadata is missing.
         missing_metadata = set(prompt.input_variables).difference(base_info)
         if len(missing_metadata) > 0:
             raise ValueError(f"Missing metadata: {list(missing_metadata)}.")
-
+
         # Filter only necessary variables for the prompt.
         document_info = {k: base_info[k] for k in prompt.input_variables}
         return prompt.format(**document_info)
@@ -37,10 +46,16 @@ def get_cite_combine_docs_chain(llm):
                 format_document(doc, i, self.document_prompt)
                 for i, doc in enumerate(docs, 1)
             ]
-
+
             # Filter only relevant input variables for the LLM chain prompt.
-            inputs = {k: v for k, v in kwargs.items() if k in self.llm_chain.prompt.input_variables}
-            inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
+            inputs = {
+                k: v
+                for k, v in kwargs.items()
+                if k in self.llm_chain.prompt.input_variables
+            }
+            inputs[self.document_variable_name] = self.document_separator.join(
+                doc_strings
+            )
             return inputs
 
     # Ref: https://huggingface.co/spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py
@@ -68,7 +83,7 @@ def get_cite_combine_docs_chain(llm):
     -----------------------
     Question: {question}
 
-    Helpful Answer with format citations:"""
+    Helpful Answer with format citations:""",
     )
 
     # Initialize the custom chain with a specific document format.
@@ -83,194 +98,8 @@ def get_cite_combine_docs_chain(llm):
         ),
         document_variable_name="context",
     )
-
-    return combine_docs_chain
-
-
-class ConversationChainFactory:
-    def __init__(
-        self, memory_key="chat_history", output_key="answer", return_messages=True
-    ):
-        self.memory_key = memory_key
-        self.output_key = output_key
-        self.return_messages = return_messages
-
-    def create(self, retriever, llm):
-        memory = ConversationBufferWindowMemory(  # ConversationBufferMemory(
-            memory_key=self.memory_key,
-            return_messages=self.return_messages,
-            output_key=self.output_key,
-        )
-
-        # prompt:
-        # https://github.com/langchain-ai/langchain/issues/6530
-
-
-        prompt_template = """You are a helpful research assistant. Use the following pieces of context to answer the question at the end.
-        Please ignore the contexts if they are not related to the question. If you don't know the answer, just say that you don't know,
-        don't try to make up an answer.
-
-        {context}
-
-        Question: {question}
-
-        Helpful Answer:"""
-        PROMPT = PromptTemplate(
-            template=prompt_template, input_variables=["context", "question"]
-        )
-
-        # Rephrase question based on history
-        # https://www.paepper.com/blog/posts/how-to-build-a-chatbot-out-of-your-website-content/
-        # tested: Be careful with the technical abbreviations and items, do not modify them unless necessary -> worse
-        # You are a helpful research assistant. -> worse, tend to expand question
-        # My testing prompt
-        # _template = """Given the following conversation and a follow up question,
-        # rephrase the follow up question to be a standalone question only when it is necessary.
-        # If the conversation is not related to the question, do not rephrase the follow up question
-        # and just put the standalone question exactly the same as the original follow up question.
-        # The standalone question should be in its original language, which is usually english.
-
-        # Chat History: {chat_history}
-
-        # Follow Up Question: {question}
-
-        # Standalone Question:"""
-
-        # Type 2: https://github.com/langchain-ai/langchain/issues/4076
-        _template = """Return text in the original language of the follow up question.
-        If the follow up question does not need context, return the exact same text back.
-        Never rephrase the follow up question given the chat history unless the follow up question needs context.
-
-        Chat History: {chat_history}
-
-        Follow Up Question: {question}
-
-        Standalone Question:"""
-        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-        # or just turn if off, see https://github.com/langchain-ai/langchain/issues/4076
-
-        # Change prompt to context-based QA
-        # system_template = """You are a professional scientist. Use the following pieces of context to answer the users question.
-        # Please ignore the contexts if they are not related to the question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
-        # ----------------
-        # {context}"""
-        # messages = [
-        #     SystemMessagePromptTemplate.from_template(system_template),
-        #     HumanMessagePromptTemplate.from_template("{question}"),
-        # ]
-        # QA_CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
-
-        # https://github.com/langchain-ai/langchain/issues/4608
-        conversation_chain = ConversationalRetrievalChain.from_llm(
-            llm=llm,
-            retriever=retriever,
-            memory=memory,
-            return_source_documents=True,
-            # return_generated_question=True,  # for debug
-            rephrase_question=False,  # Disable rephrase, for test purpose
-            get_chat_history=lambda x: x,
-            # verbose=True,
-            # combine_docs_chain_kwargs={"prompt": PROMPT},
-            # condense_question_prompt=CONDENSE_QUESTION_PROMPT,
-        )
-
-
 
+    return combine_docs_chain
-        return conversation_chain
-
-
-class ConversationChainFactoryDev:
-    def __init__(
-        self, memory_key="chat_history", output_key="answer", return_messages=True
-    ):
-        self.memory_key = memory_key
-        self.output_key = output_key
-        self.return_messages = return_messages
-
-    def create(self, retriever, llm):
-        memory = ConversationBufferWindowMemory(  # ConversationBufferMemory(
-            memory_key=self.memory_key,
-            return_messages=self.return_messages,
-            output_key=self.output_key,
-        )
-
-        # prompt:
-        # https://github.com/langchain-ai/langchain/issues/6530
-
-
-        prompt_template = """You are a helpful research assistant. Use the following pieces of context to answer the question at the end.
-        Please ignore the contexts if they are not related to the question. If you don't know the answer, just say that you don't know,
-        don't try to make up an answer.
-
-        {context}
-
-        Question: {question}
-
-        Helpful Answer:"""
-        PROMPT = PromptTemplate(
-            template=prompt_template, input_variables=["context", "question"]
-        )
-
-        # Rephrase question based on history
-        # https://www.paepper.com/blog/posts/how-to-build-a-chatbot-out-of-your-website-content/
-        # tested: Be careful with the technical abbreviations and items, do not modify them unless necessary -> worse
-        # You are a helpful research assistant. -> worse, tend to expand question
-        # My testing prompt
-        # _template = """Given the following conversation and a follow up question,
-        # rephrase the follow up question to be a standalone question only when it is necessary.
-        # If the conversation is not related to the question, do not rephrase the follow up question
-        # and just put the standalone question exactly the same as the original follow up question.
-        # The standalone question should be in its original language, which is usually english.
-
-        # Chat History: {chat_history}
-
-        # Follow Up Question: {question}
-
-        # Standalone Question:"""
-
-        # Type 2: https://github.com/langchain-ai/langchain/issues/4076
-        _template = """Return text in the original language of the follow up question.
-        If the follow up question does not need context, return the exact same text back.
-        Never rephrase the follow up question given the chat history unless the follow up question needs context.
-
-        Chat History: {chat_history}
-
-        Follow Up Question: {question}
-
-        Standalone Question:"""
-        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-        # or just turn if off, see https://github.com/langchain-ai/langchain/issues/4076
-
-        # Change prompt to context-based QA
-        # system_template = """You are a professional scientist. Use the following pieces of context to answer the users question.
-        # Please ignore the contexts if they are not related to the question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
-        # ----------------
-        # {context}"""
-        # messages = [
-        #     SystemMessagePromptTemplate.from_template(system_template),
-        #     HumanMessagePromptTemplate.from_template("{question}"),
-        # ]
-        # QA_CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
-
-        # https://github.com/langchain-ai/langchain/issues/4608
-
-
-
-        conversation_chain = ConversationalRetrievalChain.from_llm(
-            llm=llm,
-            retriever=retriever,
-            memory=memory,
-            return_source_documents=True,
-            # return_generated_question=True,  # for debug
-            rephrase_question=False,  # Disable rephrase, for test purpose
-            get_chat_history=lambda x: x,
-            # verbose=True,
-            # combine_docs_chain_kwargs={"prompt": PROMPT},
-            # condense_question_prompt=CONDENSE_QUESTION_PROMPT,
-        )
-
-
-        return conversation_chain
 
 
 class RAGChain:
@@ -302,12 +131,10 @@ class RAGChain:
             # combine_docs_chain_kwargs={"prompt": PROMPT},  # additional prompt control
             # condense_question_prompt=CONDENSE_QUESTION_PROMPT,  # additional prompt control
         )
-
+
         # Add citation, ATTENTION: experimental
        if add_citation:
-            # from models import get_cite_combine_docs_chain
             cite_combine_docs_chain = get_cite_combine_docs_chain(llm)
             conversation_chain.combine_docs_chain = cite_combine_docs_chain
 
         return conversation_chain
-
vectorestores.py CHANGED
@@ -1,7 +1,7 @@
 from langchain.vectorstores import Chroma, FAISS
 
 def get_faiss_vectorestore(embeddings):
-    # Add extra text to ini
+    # Add extra text to init
     texts = ["LISA - Lithium Ion Solid-state Assistant"]
     vectorstore = FAISS.from_texts(texts, embeddings)
 
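For context, a minimal usage sketch of the helper touched above; the query string is an assumption, and the store initially contains only the seed text shown in the diff:

from embeddings import get_jinaai_embeddings
from vectorestores import get_faiss_vectorestore

# Assumes the embedding model can be downloaded (HF token set, as in preprocess_documents.py)
embeddings = get_jinaai_embeddings(device="auto")
vectorstore = get_faiss_vectorestore(embeddings)

# FAISS similarity search over the (for now single-entry) index
hits = vectorstore.similarity_search("lithium ion assistant", k=1)
print(hits[0].page_content)  # "LISA - Lithium Ion Solid-state Assistant"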