from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import ChatPromptTemplate


class RetrievalChain:
    def __init__(self, llm, retriever, user_prompt, verbose=False):
        """
        Initializes the RetrievalChain with an LLM and retriever.

        Args:
            llm: Language model to use for the conversational chain.
            retriever: Retriever object to fetch relevant documents.
            user_prompt: Custom prompt to guide the combine-documents step.
            verbose (bool): Whether to print verbose chain outputs.
        """
        self.llm = llm
        self.chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type="stuff",
            combine_docs_chain_kwargs={"prompt": user_prompt},
            verbose=verbose,
        )
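
    # Hypothetical construction sketch (the names below are placeholders,
    # not part of this module):
    #   chain = RetrievalChain(llm=ChatOpenAI(), retriever=store.as_retriever(),
    #                          user_prompt=qa_prompt)
    # Note: with chain_type="stuff", `user_prompt` should expose {context} and
    # {question} variables so the retrieved documents can be stuffed in.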
    def summarize_messages(self, chat_history):
        """
        Summarizes the chat history into a single concise message.

        Args:
            chat_history: The chat history object for the session.

        Returns:
            bool: True if summarization was performed, False if the history
                was empty.
        """
        stored_messages = chat_history.messages
        if len(stored_messages) == 0:
            return False
        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                ("placeholder", "{chat_history}"),
                (
                    "human",
                    "Summarize the above chat messages into a single concise "
                    "message. Include only the important specific details.",
                ),
            ]
        )
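        # The ("placeholder", "{chat_history}") entry expands to a
        # MessagesPlaceholder, so the stored messages are passed through as
        # structured chat messages rather than flattened into a string.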
        # Create a summarization chain by piping the prompt into the language model.
        summarization_chain = summarization_prompt | self.llm
        summary_message = summarization_chain.invoke({"chat_history": stored_messages})
        chat_history.clear()  # Drop the full history...
        chat_history.add_ai_message(summary_message.content)  # ...and replace it with the summary.
        return True
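
    # Hypothetical usage (ChatMessageHistory is the in-memory implementation
    # from langchain_community.chat_message_histories):
    #   history = ChatMessageHistory()
    #   history.add_user_message("What is RAG?")
    #   history.add_ai_message("Retrieval-augmented generation ...")
    #   chain.summarize_messages(history)  # history now holds one AI summary message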
    def stream_chat_response(self, query, chat_id, get_chat_history, initialize_chat_history):
        """
        Streams the response to a query in real time for a given chat session,
        formatted as Server-Sent Events (SSE).

        Args:
            query (str): The user's query.
            chat_id (str): The unique ID of the chat session.
            get_chat_history (function): Returns the chat history for a chat ID,
                or None if the session is new.
            initialize_chat_history (function): Creates a new chat history for a
                chat ID.

        Yields:
            str: An SSE-formatted string for each chunk of the response.
        """
        # Retrieve the chat history for the session, creating one if none exists.
        chat_message_history = get_chat_history(chat_id)
        if not chat_message_history:
            chat_message_history = initialize_chat_history(chat_id)
        # Compress prior messages into a single summary (no-op on an empty history).
        self.summarize_messages(chat_message_history)
        chat_history = chat_message_history.messages
        # Prepare input for the conversational retrieval chain. The history is
        # captured *before* the new query is appended below.
        input_data_for_chain = {
            "question": query,
            "chat_history": chat_history,
        }
        # Record the user query in the chat history.
        chat_message_history.add_user_message(query)
        # Execute the chain in streaming mode (this assumes the chain supports
        # the Runnable `stream` method).
        response_stream = self.chain.stream(input_data_for_chain)
        accumulated_response = ""
        # Process the response stream and yield SSE events.
        for chunk in response_stream:
            if "answer" in chunk:
                accumulated_response += chunk["answer"]
                # Format and yield the SSE event.
                yield f"data: {chunk['answer']}\n\n"
            else:
                # Surface unexpected chunk structures as debug SSE events.
                yield f"data: Unexpected chunk structure: {chunk}\n\n"
        # Once streaming is complete, persist the full response in the chat history.
        if accumulated_response:
            chat_message_history.add_ai_message(accumulated_response)
        else:
            yield "data: No valid response content was generated.\n\n"