Spaces:
Sleeping
Sleeping
Léo Bourrel
committed on
Commit
·
9f90955
1
Parent(s):
5ffa07c
feat: add reranking with keyword extractor
Browse files
sorbobotapp/app.py
CHANGED
@@ -7,7 +7,6 @@ from chain import get_chain
|
|
7 |
from chat_history import insert_chat_history, insert_chat_history_articles
|
8 |
from connection import connect
|
9 |
from css import load_css
|
10 |
-
from keyword_extraction import KeywordExtractor
|
11 |
from langchain.callbacks import get_openai_callback
|
12 |
from message import Message
|
13 |
|
@@ -27,8 +26,6 @@ def initialize_session_state():
|
|
27 |
st.session_state.token_count = 0
|
28 |
if "conversation" not in st.session_state:
|
29 |
st.session_state.conversation = get_chain(conn)
|
30 |
-
if "keyword_extractor" not in st.session_state:
|
31 |
-
st.session_state.keyword_extractor = KeywordExtractor()
|
32 |
|
33 |
|
34 |
def send_message_callback():
|
|
|
7 |
from chat_history import insert_chat_history, insert_chat_history_articles
|
8 |
from connection import connect
|
9 |
from css import load_css
|
|
|
10 |
from langchain.callbacks import get_openai_callback
|
11 |
from message import Message
|
12 |
|
|
|
26 |
st.session_state.token_count = 0
|
27 |
if "conversation" not in st.session_state:
|
28 |
st.session_state.conversation = get_chain(conn)
|
|
|
|
|
29 |
|
30 |
|
31 |
def send_message_callback():
|
sorbobotapp/conversation_retrieval_chain.py
CHANGED
@@ -1,12 +1,17 @@
|
|
1 |
import inspect
|
|
|
2 |
from typing import Any, Dict, Optional
|
3 |
|
|
|
4 |
from langchain.callbacks.manager import CallbackManagerForChainRun
|
5 |
from langchain.chains.conversational_retrieval.base import (
|
6 |
ConversationalRetrievalChain, _get_chat_history)
|
|
|
7 |
|
8 |
|
9 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
|
|
|
10 |
def _handle_docs(self, docs):
|
11 |
if len(docs) == 0:
|
12 |
return False, "No documents found. Can you rephrase ?"
|
@@ -16,6 +21,33 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
16 |
return False, "Too many documents found. Can you specify your request ?"
|
17 |
return True, ""
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def _call(
|
20 |
self,
|
21 |
inputs: Dict[str, Any],
|
@@ -40,6 +72,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
40 |
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
|
41 |
else:
|
42 |
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
|
|
|
43 |
valid_docs, message = self._handle_docs(docs)
|
44 |
if not valid_docs:
|
45 |
return {
|
@@ -47,6 +80,9 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
|
47 |
"source_documents": docs,
|
48 |
}
|
49 |
|
|
|
|
|
|
|
50 |
new_inputs = inputs.copy()
|
51 |
if self.rephrase_question:
|
52 |
new_inputs["question"] = new_question
|
|
|
1 |
import inspect
|
2 |
+
import json
|
3 |
from typing import Any, Dict, Optional
|
4 |
|
5 |
+
from keyword_extraction import KeywordExtractor
|
6 |
from langchain.callbacks.manager import CallbackManagerForChainRun
|
7 |
from langchain.chains.conversational_retrieval.base import (
|
8 |
ConversationalRetrievalChain, _get_chat_history)
|
9 |
+
from langchain.schema import Document
|
10 |
|
11 |
|
12 |
class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
|
13 |
+
keyword_extractor: KeywordExtractor = KeywordExtractor()
|
14 |
+
|
15 |
def _handle_docs(self, docs):
|
16 |
if len(docs) == 0:
|
17 |
return False, "No documents found. Can you rephrase ?"
|
|
|
21 |
return False, "Too many documents found. Can you specify your request ?"
|
22 |
return True, ""
|
23 |
|
24 |
+
def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
    """Rerank documents by the number of keywords they share with the question.

    Each document's ``page_content`` is parsed as a JSON object whose optional
    ``"keywords"`` entry is read as a comma-separated string. The number of
    question keywords found in that list is stored in
    ``doc.metadata["similar_keyword"]`` (the documents are mutated in place).

    Args:
        question (str): Original question.
        docs (list[Document]): List of documents.

    Returns:
        list[Document]: New list of the same documents, best keyword match
            first. Documents with equal scores keep their incoming order.
    """
    # Materialise and normalise the question keywords once, so the same list
    # can be reused for every document (also safe if the extractor yields).
    keywords = [(kw, kw.lower().strip()) for kw in self.keyword_extractor(question)]

    for doc in docs:
        doc.metadata["similar_keyword"] = 0
        # A missing "keywords" entry is treated like a null one: score stays 0.
        raw_keywords = json.loads(doc.page_content).get("keywords")
        if not raw_keywords:
            continue

        # Strip each entry so "physics, quantum" matches "quantum", and use a
        # set for constant-time membership tests.
        doc_keywords = {part.strip() for part in raw_keywords.lower().split(",")}
        doc_keywords.discard("")

        for kw, normalized in keywords:
            if normalized in doc_keywords:
                doc.metadata["similar_keyword"] += 1
                print("similar keyword : ", kw)

    # Highest score first; sorted() is stable even with reverse=True, so ties
    # preserve the retriever's original ranking.
    return sorted(docs, key=lambda x: x.metadata["similar_keyword"], reverse=True)
|
50 |
+
|
51 |
def _call(
|
52 |
self,
|
53 |
inputs: Dict[str, Any],
|
|
|
72 |
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
|
73 |
else:
|
74 |
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
|
75 |
+
|
76 |
valid_docs, message = self._handle_docs(docs)
|
77 |
if not valid_docs:
|
78 |
return {
|
|
|
80 |
"source_documents": docs,
|
81 |
}
|
82 |
|
83 |
+
# Add reranking
|
84 |
+
docs = self.rerank_documents(new_question, docs)
|
85 |
+
|
86 |
new_inputs = inputs.copy()
|
87 |
if self.rephrase_question:
|
88 |
new_inputs["question"] = new_question
|