Léo Bourrel committed on
Commit
b9e3c29
·
1 Parent(s): 8505f96

feat: add custom retrieval

Browse files
Files changed (2) hide show
  1. app.py +4 -2
  2. conversation_retrieval_chain.py +64 -0
app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import streamlit as st
5
  import streamlit.components.v1 as components
6
  from langchain.callbacks import get_openai_callback
7
- from langchain.chains import ConversationalRetrievalChain
8
  from langchain.chains.conversation.memory import ConversationBufferMemory
9
  from langchain.embeddings import GPT4AllEmbeddings
10
  from langchain.llms import OpenAI
@@ -14,6 +14,8 @@ from connection import connect
14
  from css import load_css
15
  from message import Message
16
  from vector_store import CustomVectorStore
 
 
17
 
18
  st.set_page_config(layout="wide")
19
 
@@ -50,7 +52,7 @@ def initialize_session_state():
50
  memory = ConversationBufferMemory(
51
  output_key="answer", memory_key="chat_history", return_messages=True
52
  )
53
- st.session_state.conversation = ConversationalRetrievalChain.from_llm(
54
  llm=llm,
55
  retriever=retriever,
56
  verbose=True,
 
4
  import streamlit as st
5
  import streamlit.components.v1 as components
6
  from langchain.callbacks import get_openai_callback
7
+
8
  from langchain.chains.conversation.memory import ConversationBufferMemory
9
  from langchain.embeddings import GPT4AllEmbeddings
10
  from langchain.llms import OpenAI
 
14
  from css import load_css
15
  from message import Message
16
  from vector_store import CustomVectorStore
17
+ from conversation_retrieval_chain import CustomConversationalRetrievalChain
18
+
19
 
20
  st.set_page_config(layout="wide")
21
 
 
52
  memory = ConversationBufferMemory(
53
  output_key="answer", memory_key="chat_history", return_messages=True
54
  )
55
+ st.session_state.conversation = CustomConversationalRetrievalChain.from_llm(
56
  llm=llm,
57
  retriever=retriever,
58
  verbose=True,
conversation_retrieval_chain.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from typing import Any, Dict, Optional
3
+
4
+ from langchain.chains.conversational_retrieval.base import (
5
+ ConversationalRetrievalChain,
6
+ _get_chat_history,
7
+ )
8
+ from langchain.callbacks.manager import CallbackManagerForChainRun
9
+
10
+
11
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """Conversational retrieval chain that sanity-checks the retrieved
    documents before answering.

    If the retriever returns an unusable number of documents (none, exactly
    one, or more than ten), the chain short-circuits and returns a
    user-facing message asking to rephrase, instead of running the
    combine-docs step.
    """

    def _handle_docs(self, docs):
        """Validate the retrieved document set.

        Args:
            docs: Documents returned by ``self._get_docs``.

        Returns:
            A ``(valid, message)`` tuple. ``valid`` is False when the result
            set is unusable; ``message`` then carries the user-facing
            explanation that becomes the chain's answer.
        """
        # NOTE(review): rejecting a single-document result looks deliberate
        # (too little context to answer from) — confirm with the author.
        if len(docs) == 0:
            return False, "No documents found. Can you rephrase ?"
        elif len(docs) == 1:
            return False, "Only one document found. Can you rephrase ?"
        elif len(docs) > 10:
            return False, "Too many documents found. Can you specify your request ?"
        return True, ""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain: condense the question, retrieve, validate, answer.

        Args:
            inputs: Must contain ``"question"`` and ``"chat_history"``.
            run_manager: Optional callback manager for child runs.

        Returns:
            A dict keyed by ``self.output_key`` (the answer), plus
            ``"source_documents"`` and/or ``"generated_question"`` when the
            corresponding ``return_*`` flags are set.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        # With prior history, condense question + history into a standalone
        # question; otherwise use the question as-is.
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question

        # Older retriever implementations do not accept run_manager; inspect
        # the signature to stay compatible with both.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]

        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            # Short-circuit: the explanation becomes the answer.
            # BUG FIX: emit exactly the keys declared by ``output_keys`` —
            # previously this always returned "source_documents" (even when
            # return_source_documents is False) and never returned
            # "generated_question", which breaks LangChain's output
            # validation when return_generated_question is True.
            output: Dict[str, Any] = {self.output_key: message}
            if self.return_source_documents:
                output["source_documents"] = docs
            if self.return_generated_question:
                output["generated_question"] = new_question
            return output

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output