""" Main RAG chain based on langchain. """ from langchain.chains import LLMChain from langchain.prompts import ( SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, PromptTemplate, ) from langchain.chains import ConversationalRetrievalChain from langchain.chains.conversation.memory import ( ConversationBufferWindowMemory, ) from langchain.chains import StuffDocumentsChain def get_cite_combine_docs_chain(llm): """Get doc chain which adds metadata to text chunks.""" # Ref: https://github.com/langchain-ai/langchain/issues/7239 # Function to format each document with an index, source, and content. def format_document(doc, index, prompt): """Format a document into a string based on a prompt template.""" # Create a dictionary with document content and metadata. base_info = { "page_content": doc.page_content, "index": index, "source": doc.metadata["source"], } # Check if any metadata is missing. missing_metadata = set(prompt.input_variables).difference(base_info) if len(missing_metadata) > 0: raise ValueError(f"Missing metadata: {list(missing_metadata)}.") # Filter only necessary variables for the prompt. document_info = {k: base_info[k] for k in prompt.input_variables} return prompt.format(**document_info) # Custom chain class to handle document combination with source indices. class StuffDocumentsWithIndexChain(StuffDocumentsChain): """Custom chain class to handle document combination with source indices.""" def _get_inputs(self, docs, **kwargs): """Overwrite _get_inputs to add metadata for text chunks.""" # Format each document and combine them. doc_strings = [ format_document(doc, i, self.document_prompt) for i, doc in enumerate(docs, 1) ] # Filter only relevant input variables for the LLM chain prompt. inputs = { k: v for k, v in kwargs.items() if k in self.llm_chain.prompt.input_variables } inputs[self.document_variable_name] = self.document_separator.join( doc_strings ) return inputs # Main prompt for RAG chain with citation # Ref: https://huggingface.co/spaces/Ekimetrics/climate-question-answering/blob/main/climateqa/engine/prompts.py # Define a chat prompt with instructions for citing documents. combine_doc_prompt = PromptTemplate( input_variables=["context", "question"], template="""You are given a question and passages. Provide a clear and structured Helpful Answer based on the passages provided, the context and the guidelines. Guidelines: - If the passages have useful facts or numbers, use them in your answer. - When you use information from a passage, mention where it came from by using format [[i]] at the end of the sentence. i stands for the paper index of the document. - Do not cite the passage in a style like 'passage i', always use format [[i]] where i stands for the passage index of the document. - Do not use the sentence such as 'Doc i says ...' or '... in Doc i' or 'Passage i ...' to say where information came from. - If the same thing is said in more than one document, you can mention all of them like this: [[i]], [[j]], [[k]]. - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation. - If it makes sense, use bullet points and lists to make your answers easier to understand. - You do not need to use every passage. Only use the ones that help answer the question. - If the documents do not have the information needed to answer the question, just say you do not have enough information. 
- If the passage is the caption of a picture, you can still use it as part of your answer as any other document. ----------------------- Passages: {context} ----------------------- Question: {question} Helpful Answer with format citations:""", ) # Initialize the custom chain with a specific document format. combine_docs_chain = StuffDocumentsWithIndexChain( llm_chain=LLMChain( llm=llm, prompt=combine_doc_prompt, ), document_prompt=PromptTemplate( input_variables=["index", "source", "page_content"], template="[[{index}]]\nsource: {source}:\n{page_content}", ), document_variable_name="context", ) return combine_docs_chain class RAGChain: """Main RAG chain.""" def __init__( self, memory_key="chat_history", output_key="answer", return_messages=True ): self.memory_key = memory_key self.output_key = output_key self.return_messages = return_messages def create(self, retriever, llm, add_citation=False): """Create a rag chain instance.""" # Memory is kept for later support of conversational chat memory = ConversationBufferWindowMemory( # Or ConversationBufferMemory k=2, memory_key=self.memory_key, return_messages=self.return_messages, output_key=self.output_key, ) # Ref: https://github.com/langchain-ai/langchain/issues/4608 conversation_chain = ConversationalRetrievalChain.from_llm( llm=llm, retriever=retriever, memory=memory, return_source_documents=True, rephrase_question=False, # disable rephrase, for test purpose get_chat_history=lambda x: x, # return_generated_question=True, # for debug # combine_docs_chain_kwargs={"prompt": PROMPT}, # additional prompt control # condense_question_prompt=CONDENSE_QUESTION_PROMPT, # additional prompt control ) # Add citation, ATTENTION: experimental if add_citation: cite_combine_docs_chain = get_cite_combine_docs_chain(llm) conversation_chain.combine_docs_chain = cite_combine_docs_chain return conversation_chain
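

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It shows one way
# to wire the chain together end to end. The in-memory FAISS store, the sample
# document, and the OpenAI models are assumptions, not requirements of
# RAGChain; any retriever and LLM should work, as long as each retrieved
# document carries a "source" entry in its metadata, since the citation
# document prompt above formats it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.schema import Document
    from langchain.vectorstores import FAISS

    # Hypothetical one-passage corpus with the metadata the chain expects.
    docs = [
        Document(
            page_content="The handbook recommends a default chunk size of 500 tokens.",
            metadata={"source": "handbook.pdf"},
        ),
    ]
    retriever = FAISS.from_documents(docs, OpenAIEmbeddings()).as_retriever()

    chain = RAGChain().create(retriever, ChatOpenAI(temperature=0), add_citation=True)
    result = chain({"question": "What chunk size does the handbook recommend?"})
    print(result["answer"])  # Expect an answer citing the passage as [[1]].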