Spaces:
Running
Running
Léo Bourrel
committed on
Commit
·
b9e3c29
1
Parent(s):
8505f96
feat: add custom retrieval
Browse files
- app.py +4 -2
- conversation_retrieval_chain.py +64 -0
app.py
CHANGED
@@ -4,7 +4,7 @@ import os
|
|
4 |
import streamlit as st
|
5 |
import streamlit.components.v1 as components
|
6 |
from langchain.callbacks import get_openai_callback
|
7 |
-
|
8 |
from langchain.chains.conversation.memory import ConversationBufferMemory
|
9 |
from langchain.embeddings import GPT4AllEmbeddings
|
10 |
from langchain.llms import OpenAI
|
@@ -14,6 +14,8 @@ from connection import connect
|
|
14 |
from css import load_css
|
15 |
from message import Message
|
16 |
from vector_store import CustomVectorStore
|
|
|
|
|
17 |
|
18 |
st.set_page_config(layout="wide")
|
19 |
|
@@ -50,7 +52,7 @@ def initialize_session_state():
|
|
50 |
memory = ConversationBufferMemory(
|
51 |
output_key="answer", memory_key="chat_history", return_messages=True
|
52 |
)
|
53 |
-
st.session_state.conversation =
|
54 |
llm=llm,
|
55 |
retriever=retriever,
|
56 |
verbose=True,
|
|
|
4 |
import streamlit as st
|
5 |
import streamlit.components.v1 as components
|
6 |
from langchain.callbacks import get_openai_callback
|
7 |
+
|
8 |
from langchain.chains.conversation.memory import ConversationBufferMemory
|
9 |
from langchain.embeddings import GPT4AllEmbeddings
|
10 |
from langchain.llms import OpenAI
|
|
|
14 |
from css import load_css
|
15 |
from message import Message
|
16 |
from vector_store import CustomVectorStore
|
17 |
+
from conversation_retrieval_chain import CustomConversationalRetrievalChain
|
18 |
+
|
19 |
|
20 |
st.set_page_config(layout="wide")
|
21 |
|
|
|
52 |
memory = ConversationBufferMemory(
|
53 |
output_key="answer", memory_key="chat_history", return_messages=True
|
54 |
)
|
55 |
+
st.session_state.conversation = CustomConversationalRetrievalChain.from_llm(
|
56 |
llm=llm,
|
57 |
retriever=retriever,
|
58 |
verbose=True,
|
conversation_retrieval_chain.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
from typing import Any, Dict, Optional
|
3 |
+
|
4 |
+
from langchain.chains.conversational_retrieval.base import (
|
5 |
+
ConversationalRetrievalChain,
|
6 |
+
_get_chat_history,
|
7 |
+
)
|
8 |
+
from langchain.callbacks.manager import CallbackManagerForChainRun
|
9 |
+
|
10 |
+
|
11 |
+
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
    """ConversationalRetrievalChain that rejects unhelpful retrieval results.

    Behaves like the stock ``ConversationalRetrievalChain`` except that, after
    retrieving documents, it validates the result set with ``_handle_docs`` and
    short-circuits with a canned message (skipping the expensive
    ``combine_docs_chain`` LLM call) when the retrieval looks unusable.
    """

    def _handle_docs(self, docs):
        """Decide whether the retrieved ``docs`` are worth answering from.

        Returns a ``(valid, message)`` tuple: ``valid`` is False when the
        result set is empty, a single document, or more than 10 documents,
        in which case ``message`` is the user-facing refusal text.
        """
        if len(docs) == 0:
            return False, "No documents found. Can you rephrase ?"
        elif len(docs) == 1:
            return False, "Only one document found. Can you rephrase ?"
        elif len(docs) > 10:
            return False, "Too many documents found. Can you specify your request ?"
        return True, ""

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the chain: condense the question, retrieve, validate, answer.

        ``inputs`` must contain ``"question"`` and ``"chat_history"``.
        Returns a dict keyed by ``self.output_key`` plus, depending on the
        chain's flags, ``"source_documents"`` and ``"generated_question"``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs["question"]
        get_chat_history = self.get_chat_history or _get_chat_history
        chat_history_str = get_chat_history(inputs["chat_history"])

        # With prior history, rewrite the follow-up into a standalone question;
        # otherwise the raw question is already standalone.
        if chat_history_str:
            callbacks = _run_manager.get_child()
            new_question = self.question_generator.run(
                question=question, chat_history=chat_history_str, callbacks=callbacks
            )
        else:
            new_question = question
        # Older _get_docs signatures do not accept run_manager; probe first.
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
        else:
            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]
        valid_docs, message = self._handle_docs(docs)
        if not valid_docs:
            # Short-circuit with the refusal message. Include every output key
            # the chain declares: omitting "generated_question" here made
            # Chain output validation fail with "Missing some output keys"
            # whenever return_generated_question=True.
            output: Dict[str, Any] = {
                self.output_key: message,
                "source_documents": docs,
            }
            if self.return_generated_question:
                output["generated_question"] = new_question
            return output

        new_inputs = inputs.copy()
        if self.rephrase_question:
            new_inputs["question"] = new_question
        new_inputs["chat_history"] = chat_history_str
        answer = self.combine_docs_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
        )
        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_source_documents:
            output["source_documents"] = docs
        if self.return_generated_question:
            output["generated_question"] = new_question
        return output
|