from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import ChatPromptTemplate
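
# Illustrative example of the prompt expected for `user_prompt`: the 'stuff'
# combine-documents step fills in `context` (the retrieved documents) and
# `question` (the user query). The exact wording here is an assumption;
# adapt it to your application.
EXAMPLE_USER_PROMPT = ChatPromptTemplate.from_template(
    "Use the following context to answer the question. If the answer is not "
    "in the context, say so.\n\n"
    "Context:\n{context}\n\n"
    "Question: {question}"
)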

class RetrievalChain:
    def __init__(self, llm, retriever, user_prompt, verbose=False):
        """
        Initializes the RetrievalChain with an LLM and retriever.
        
        Args:
            llm: Chat model used for both answering and history summarization.
            retriever: Retriever object that fetches relevant documents.
            user_prompt: Prompt template for the 'stuff' combine-documents
                step; it should expose `context` and `question` variables.
            verbose (bool): Whether to print verbose chain output.
        """
        self.llm = llm

        self.chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type='stuff',
            combine_docs_chain_kwargs={"prompt": user_prompt},
            verbose=verbose,
        )

    def summarize_messages(self, chat_history):
        """
        Summarizes the chat history into a concise message.

        Args:
            chat_history: The chat history object for the session.

        Returns:
            bool: True if the history was summarized, False if it was empty.
        """
        stored_messages = chat_history.messages
        if len(stored_messages) == 0:
            return False

        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                ("placeholder", "{chat_history}"),
                (
                    "human",
                    "Summarize the above chat messages into a single concise message. Include only the important specific details.",
                ),
            ]
        )
        # Create a chain for summarization by piping the prompt into the language model.
        summarization_chain = summarization_prompt | self.llm
        summary_message = summarization_chain.invoke({"chat_history": stored_messages})
        
        chat_history.clear()  # Clear the existing chat history
        chat_history.add_ai_message(summary_message.content)  # Add the summary message as the first entry
        return True

    def stream_chat_response(self, query, chat_id, get_chat_history, initialize_chat_history):
        """
        Streams the response to a query in real time for a given chat session using SSE formatting.

        Args:
            query (str): The user's query.
            chat_id (str): The unique ID of the chat session.
            get_chat_history (function): Function to retrieve chat history by chat ID.
            initialize_chat_history (function): Function to initialize a new chat history.
        
        Yields:
            str: Server-Sent Event (SSE) formatted string for each chunk of the response.
        """
        # Retrieve the chat history for the session.
        chat_message_history = get_chat_history(chat_id)
        if not chat_message_history:
            # If no chat history exists, initialize one.
            chat_message_history = initialize_chat_history(chat_id)

        # Condense any earlier messages into a single summary so the context
        # passed to the chain stays small (a no-op when the history is empty).
        self.summarize_messages(chat_message_history)
        # Copy the messages so that adding the user query below does not
        # mutate the history we hand to the chain.
        chat_history = list(chat_message_history.messages)

        # Prepare input data for the conversational retrieval chain.
        input_data_for_chain = {
            "question": query,
            "chat_history": chat_history
        }

        # Add the user query to the chat history.
        chat_message_history.add_user_message(query)

        # Execute the chain in streaming mode. LangChain chains expose the
        # Runnable `stream` interface, which yields partial output dicts.
        response_stream = self.chain.stream(input_data_for_chain)

        accumulated_response = ""
        # Process the response stream and yield SSE events.
        for chunk in response_stream:
            if 'answer' in chunk:
                accumulated_response += chunk['answer']
                # Format the SSE event. Embedded newlines must be split into
                # separate `data:` lines to stay spec-compliant; SSE clients
                # rejoin them with "\n".
                payload = chunk['answer'].replace("\n", "\ndata: ")
                yield f"data: {payload}\n\n"
            else:
                # Yield an SSE event with debug info if the chunk structure is unexpected.
                debug_msg = f"Unexpected chunk structure: {chunk}"
                yield f"data: {debug_msg}\n\n"

        # Once streaming is complete, update chat history with the final response.
        if accumulated_response:
            chat_message_history.add_ai_message(accumulated_response)
        else:
            yield "data: No valid response content was generated.\n\n"