import json
import os
from typing import List

import pandas as pd
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.globals import get_debug
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from app_modules.llm_inference import LLMInference
from app_modules.utils import CustomizedConversationSummaryBufferMemory
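
# Configuration is read from the environment once at import time (assumption:
# the variables are provided by the deployment, e.g. via a .env file loaded
# elsewhere). RETRIEVER_TYPE selects between the vector-store retriever and the
# questions-file lookup retriever defined below; APPLY_CHAT_TEMPLATE_FOR_RAG
# controls how QAChain.get_prompt formats the final prompt.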
retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
apply_chat_template_for_rag = os.getenv("APPLY_CHAT_TEMPLATE_FOR_RAG") == "true"

print(f"retrieve_from_questions_file: {retrieve_from_questions_file}", flush=True)
print(f"apply_chat_template_for_rag: {apply_chat_template_for_rag}", flush=True)

if retrieve_from_questions_file:
    # The questions file is expected to provide `question`, `context`, and `id`
    # columns; they are used by DatasetRetriever and QAChain.get_prompt below.
    questions_file_path = os.getenv("QUESTIONS_FILE_PATH")
    questions_df = pd.read_json(questions_file_path)
    print(f"Questions file loaded: {questions_file_path}", flush=True)

class DatasetRetriever(BaseRetriever):
    """Retriever that serves pre-built contexts from the questions file
    instead of querying a vector store."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        docs = []
        df = questions_df

        # Exact, case-insensitive match against the question column.
        filtered = df[df["question"].str.lower() == query.lower()]

        for i in range(len(filtered)):
            docs.append(
                Document(
                    page_content=filtered.iloc[i]["context"],
                    metadata={"source": filtered.iloc[i]["id"]},
                )
            )

        if not docs:
            print(f"No documents found for query: {query}", flush=True)

        return docs

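
# QAChain wires the selected retriever into a ConversationalRetrievalChain,
# optionally adding summary-buffer memory for multi-turn chat history.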
class QAChain(LLMInference):
    def __init__(self, vectorstore, llm_loader):
        super().__init__(llm_loader)
        self.vectorstore = vectorstore

    def create_chain(self) -> Chain:
        if retrieve_from_questions_file:
            retriever = DatasetRetriever()
        else:
            retriever = self.vectorstore.as_retriever(
                search_kwargs=self.llm_loader.search_kwargs
            )

        if os.environ.get("CHAT_HISTORY_ENABLED") == "true":
            # Keep a summarised chat history so follow-up questions can be
            # condensed into standalone queries.
            memory = CustomizedConversationSummaryBufferMemory(
                llm=self.llm_loader.llm,
                output_key="answer",
                memory_key="chat_history",
                max_token_limit=1024,
                return_messages=True,
            )
            qa = ConversationalRetrievalChain.from_llm(
                self.llm_loader.llm,
                memory=memory,
                chain_type="stuff",
                retriever=retriever,
                get_chat_history=lambda h: h,
                return_source_documents=True,
            )
        else:
            qa = ConversationalRetrievalChain.from_llm(
                self.llm_loader.llm,
                retriever=retriever,
                max_tokens_limit=8192,
                return_source_documents=True,
            )

        return qa

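    # When a list of inputs is passed (batch mode) and the model is a local
    # HuggingFace model, prompts are rendered here via get_prompt instead of
    # relying on the chain's own prompt template.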
    def _process_inputs(self, inputs):
        if isinstance(inputs, list) and self.llm_loader.llm_model_type == "huggingface":
            inputs = [self.get_prompt(i) for i in inputs]

            if get_debug():
                print("_process_inputs:", json.dumps(inputs, indent=4))

        return inputs

    def get_prompt(self, inputs):
        qa_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."

        # Assumes the questions-file retriever is active, so the context for
        # the query can be looked up directly from questions_df.
        df = questions_df
        query = inputs["question"]

        filtered = df[df["question"].str.lower() == query.lower()]

        context = filtered.iloc[0]["context"] if len(filtered) > 0 else ""

        if apply_chat_template_for_rag:
            return self.apply_chat_template(
                f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}"
            )
        else:
            return f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}\n\nHelpful Answer:"
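

# Example usage (a minimal sketch; `vectorstore` and `llm_loader` are assumed
# to be constructed elsewhere in the app before this chain is built):
#
#     qa_chain = QAChain(vectorstore, llm_loader)
#     chain = qa_chain.create_chain()
#     result = chain.invoke(
#         {"question": "What is retrieval-augmented generation?", "chat_history": []}
#     )
#     print(result["answer"])
#     print(result["source_documents"])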