from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import ChatPromptTemplate


class RetrievalChain:
    def __init__(self, llm, retriever, user_prompt, verbose=False):
        """
        Initializes the RetrievalChain with an LLM and retriever.

        Args:
            llm: Language model to use for the conversational chain.
            retriever: Retriever object to fetch relevant documents.
            user_prompt: Custom prompt to guide the chain.
            verbose (bool): Whether to print verbose chain outputs.
        """
        self.llm = llm
        self.chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type="stuff",
            combine_docs_chain_kwargs={"prompt": user_prompt},
            verbose=verbose,
        )
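
    # Note: with chain_type="stuff", the combine-docs prompt passed in as
    # user_prompt is expected to expose "context" and "question" input
    # variables. A minimal compatible prompt might look like this (an
    # illustrative sketch, not taken from the original code):
    #
    #   user_prompt = ChatPromptTemplate.from_template(
    #       "Answer using only the context below.\n\n"
    #       "Context:\n{context}\n\nQuestion: {question}"
    #   )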

    def summarize_messages(self, chat_history):
        """
        Summarizes the chat history into a single concise message.

        Args:
            chat_history: The chat history object for the session.

        Returns:
            bool: True if summarization succeeded, False if the history was empty.
        """
        stored_messages = chat_history.messages
        if len(stored_messages) == 0:
            return False
        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                ("placeholder", "{chat_history}"),
                (
                    "human",
                    "Summarize the above chat messages into a single concise message. "
                    "Include only the important specific details.",
                ),
            ]
        )
        # Create a summarization chain by piping the prompt into the language model.
        summarization_chain = summarization_prompt | self.llm
        summary_message = summarization_chain.invoke({"chat_history": stored_messages})
        # Replace the full history with the summary so the context stays small.
        chat_history.clear()
        chat_history.add_ai_message(summary_message.content)
        return True
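
    # Usage sketch for summarize_messages (illustrative, not from the original
    # module), assuming ChatMessageHistory from
    # langchain_community.chat_message_histories:
    #
    #   history = ChatMessageHistory()
    #   history.add_user_message("Which vector stores are supported?")
    #   history.add_ai_message("FAISS and Chroma are both supported.")
    #   chain.summarize_messages(history)
    #   # history now holds a single AI message containing the summary.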

    def stream_chat_response(self, query, chat_id, get_chat_history, initialize_chat_history):
        """
        Streams the response to a query in real time for a given chat session,
        formatted as Server-Sent Events (SSE).

        Args:
            query (str): The user's query.
            chat_id (str): The unique ID of the chat session.
            get_chat_history (function): Function to retrieve chat history by chat ID.
            initialize_chat_history (function): Function to initialize a new chat history.

        Yields:
            str: An SSE-formatted string for each chunk of the response.
        """
        # Retrieve the chat history for the session, initializing one if needed.
        chat_message_history = get_chat_history(chat_id)
        if not chat_message_history:
            chat_message_history = initialize_chat_history(chat_id)
        # Condense earlier messages into a single summary to keep the prompt small.
        self.summarize_messages(chat_message_history)
        chat_history = chat_message_history.messages
        # Prepare input for the conversational retrieval chain. The history is
        # captured *before* the current query is appended to it.
        input_data_for_chain = {
            "question": query,
            "chat_history": chat_history,
        }
        # Record the user query in the chat history.
        chat_message_history.add_user_message(query)
        # Execute the chain in streaming mode (this assumes the chain supports a `stream` method).
        response_stream = self.chain.stream(input_data_for_chain)
        accumulated_response = ""
        # Process the response stream and yield SSE events.
        for chunk in response_stream:
            if "answer" in chunk:
                accumulated_response += chunk["answer"]
                # Format the SSE event. (If a chunk can contain newlines, each
                # line must carry its own "data: " prefix to remain valid SSE.)
                yield f"data: {chunk['answer']}\n\n"
            else:
                # Yield an SSE event with debug info if the chunk structure is unexpected.
                yield f"data: Unexpected chunk structure: {chunk}\n\n"
        # Once streaming is complete, record the full response in the chat history.
        if accumulated_response:
            chat_message_history.add_ai_message(accumulated_response)
        else:
            yield "data: No valid response content was generated.\n\n"