from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import ChatPromptTemplate


class RetrievalChain:
    def __init__(self, llm, retriever, user_prompt, verbose=False):
        """
        Initializes the RetrievalChain with an LLM and retriever.

        Args:
            llm: Language model to use for the conversational chain.
            retriever: Retriever object to fetch relevant documents.
            user_prompt: Custom prompt to guide the chain.
            verbose (bool): Whether to print verbose chain outputs.
        """
        self.llm = llm
        self.chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type="stuff",
            combine_docs_chain_kwargs={"prompt": user_prompt},
            verbose=verbose,
        )

    def summarize_messages(self, chat_history):
        """
        Summarizes the chat history into a single concise message.

        Args:
            chat_history: The chat history object for the session.

        Returns:
            bool: True if summarization succeeded, False if there was
            nothing to summarize.
        """
        stored_messages = chat_history.messages
        if len(stored_messages) == 0:
            return False

        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                ("placeholder", "{chat_history}"),
                (
                    "human",
                    "Summarize the above chat messages into a single "
                    "concise message. Include only the important "
                    "specific details.",
                ),
            ]
        )

        # Create a summarization chain by piping the prompt into the language model.
        summarization_chain = summarization_prompt | self.llm
        summary_message = summarization_chain.invoke({"chat_history": stored_messages})

        # Replace the full history with the summary as its single entry.
        chat_history.clear()
        chat_history.add_ai_message(summary_message.content)
        return True

    def stream_chat_response(self, query, chat_id, get_chat_history, initialize_chat_history):
        """
        Streams the response to a query in real time for a given chat session
        using Server-Sent Events (SSE) formatting.

        Args:
            query (str): The user's query.
            chat_id (str): The unique ID of the chat session.
            get_chat_history (function): Function to retrieve chat history by chat ID.
            initialize_chat_history (function): Function to initialize a new chat history.

        Yields:
            str: SSE-formatted string for each chunk of the response.
        """
        # Retrieve the chat history for the session, initializing one if none exists.
        chat_message_history = get_chat_history(chat_id)
        if not chat_message_history:
            chat_message_history = initialize_chat_history(chat_id)

        # Compress previous messages into a single summary to keep the prompt compact.
        self.summarize_messages(chat_message_history)

        # Snapshot the history *before* appending the new query, so the question
        # is not duplicated inside the chat_history passed to the chain.
        chat_history = list(chat_message_history.messages)

        # Prepare input data for the conversational retrieval chain.
        input_data_for_chain = {
            "question": query,
            "chat_history": chat_history,
        }

        # Record the user query in the stored chat history.
        chat_message_history.add_user_message(query)

        # Execute the chain in streaming mode (this assumes the chain supports
        # a `stream` method).
        response_stream = self.chain.stream(input_data_for_chain)

        accumulated_response = ""
        # Process the response stream and yield SSE events.
        for chunk in response_stream:
            if "answer" in chunk:
                accumulated_response += chunk["answer"]
                # Format the SSE event. Note: per the SSE spec, a chunk that
                # contains newlines would need each line prefixed with "data: ".
                yield f"data: {chunk['answer']}\n\n"
            else:
                # Yield an SSE event with debug info if the chunk structure is unexpected.
                yield f"data: Unexpected chunk structure: {chunk}\n\n"

        # Once streaming is complete, update the chat history with the final response.
        if accumulated_response:
            chat_message_history.add_ai_message(accumulated_response)
        else:
            yield "data: No valid response content was generated.\n\n"
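

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the class above). The LLM,
# embeddings, vector store, and in-memory history registry below are
# assumptions -- swap in whatever your application actually uses. Import
# paths follow the classic `langchain` package layout used at the top of
# this file and may need adjusting for the newer split-out packages
# (langchain-openai, langchain-community, etc.).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI          # assumed LLM
    from langchain.memory import ChatMessageHistory       # assumed history store
    from langchain.vectorstores import FAISS              # assumed vector store
    from langchain.embeddings import OpenAIEmbeddings     # assumed embeddings

    # Build a toy retriever over a couple of in-memory documents.
    vector_store = FAISS.from_texts(
        ["LangChain chains support streaming.", "SSE frames events with 'data:' lines."],
        embedding=OpenAIEmbeddings(),
    )
    retriever = vector_store.as_retriever()

    # Prompt for the 'stuff' combine-documents chain; it must expose the
    # {context} and {question} variables that chain expects.
    prompt = ChatPromptTemplate.from_template(
        "Answer using only this context:\n{context}\n\nQuestion: {question}"
    )

    chain = RetrievalChain(
        llm=ChatOpenAI(temperature=0, streaming=True),
        retriever=retriever,
        user_prompt=prompt,
        verbose=True,
    )

    # Minimal in-memory chat-history registry keyed by chat_id; a real app
    # would likely back this with Redis, a database, etc.
    histories = {}

    def get_chat_history(chat_id):
        return histories.get(chat_id)

    def initialize_chat_history(chat_id):
        histories[chat_id] = ChatMessageHistory()
        return histories[chat_id]

    # Consume the generator exactly as an SSE endpoint would, printing each
    # "data: ..." event as it arrives.
    for event in chain.stream_chat_response(
        "How does SSE frame events?",
        "chat-1",
        get_chat_history,
        initialize_chat_history,
    ):
        print(event, end="")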