Léo Bourrel commited on
Commit
9f90955
·
1 Parent(s): 5ffa07c

feat: add reranking with keyword extractor

Browse files
sorbobotapp/app.py CHANGED
@@ -7,7 +7,6 @@ from chain import get_chain
7
  from chat_history import insert_chat_history, insert_chat_history_articles
8
  from connection import connect
9
  from css import load_css
10
- from keyword_extraction import KeywordExtractor
11
  from langchain.callbacks import get_openai_callback
12
  from message import Message
13
 
@@ -27,8 +26,6 @@ def initialize_session_state():
27
  st.session_state.token_count = 0
28
  if "conversation" not in st.session_state:
29
  st.session_state.conversation = get_chain(conn)
30
- if "keyword_extractor" not in st.session_state:
31
- st.session_state.keyword_extractor = KeywordExtractor()
32
 
33
 
34
  def send_message_callback():
 
7
  from chat_history import insert_chat_history, insert_chat_history_articles
8
  from connection import connect
9
  from css import load_css
 
10
  from langchain.callbacks import get_openai_callback
11
  from message import Message
12
 
 
26
  st.session_state.token_count = 0
27
  if "conversation" not in st.session_state:
28
  st.session_state.conversation = get_chain(conn)
 
 
29
 
30
 
31
  def send_message_callback():
sorbobotapp/conversation_retrieval_chain.py CHANGED
@@ -1,12 +1,17 @@
1
  import inspect
 
2
  from typing import Any, Dict, Optional
3
 
 
4
  from langchain.callbacks.manager import CallbackManagerForChainRun
5
  from langchain.chains.conversational_retrieval.base import (
6
  ConversationalRetrievalChain, _get_chat_history)
 
7
 
8
 
9
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
 
 
10
  def _handle_docs(self, docs):
11
  if len(docs) == 0:
12
  return False, "No documents found. Can you rephrase ?"
@@ -16,6 +21,33 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
16
  return False, "Too many documents found. Can you specify your request ?"
17
  return True, ""
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def _call(
20
  self,
21
  inputs: Dict[str, Any],
@@ -40,6 +72,7 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
40
  docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
41
  else:
42
  docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
 
43
  valid_docs, message = self._handle_docs(docs)
44
  if not valid_docs:
45
  return {
@@ -47,6 +80,9 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
47
  "source_documents": docs,
48
  }
49
 
 
 
 
50
  new_inputs = inputs.copy()
51
  if self.rephrase_question:
52
  new_inputs["question"] = new_question
 
1
  import inspect
2
+ import json
3
  from typing import Any, Dict, Optional
4
 
5
+ from keyword_extraction import KeywordExtractor
6
  from langchain.callbacks.manager import CallbackManagerForChainRun
7
  from langchain.chains.conversational_retrieval.base import (
8
  ConversationalRetrievalChain, _get_chat_history)
9
+ from langchain.schema import Document
10
 
11
 
12
  class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
13
+ keyword_extractor: KeywordExtractor = KeywordExtractor()
14
+
15
  def _handle_docs(self, docs):
16
  if len(docs) == 0:
17
  return False, "No documents found. Can you rephrase ?"
 
21
  return False, "Too many documents found. Can you specify your request ?"
22
  return True, ""
23
 
24
+ def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
25
+ """Rerank documents based on the number of similar keywords
26
+
27
+ Args:
28
+ question (str): Orinal question
29
+ docs (list[Document]): List of documents
30
+
31
+ Returns:
32
+ list[Document]: List of documents sorted by the number of similar keywords
33
+ """
34
+ keywords = self.keyword_extractor(question)
35
+
36
+ for doc in docs:
37
+ doc.metadata["similar_keyword"] = 0
38
+ doc_keywords = json.loads(doc.page_content)["keywords"]
39
+ if doc_keywords is None:
40
+ continue
41
+ doc_keywords = doc_keywords.lower().split(",")
42
+
43
+ for kw in keywords:
44
+ if kw.lower() in doc_keywords:
45
+ doc.metadata["similar_keyword"] += 1
46
+ print("similar keyword : ", kw)
47
+
48
+ docs = sorted(docs, key=lambda x: x.metadata["similar_keyword"])
49
+ return docs
50
+
51
  def _call(
52
  self,
53
  inputs: Dict[str, Any],
 
72
  docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
73
  else:
74
  docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
75
+
76
  valid_docs, message = self._handle_docs(docs)
77
  if not valid_docs:
78
  return {
 
80
  "source_documents": docs,
81
  }
82
 
83
+ # Add reranking
84
+ docs = self.rerank_documents(new_question, docs)
85
+
86
  new_inputs = inputs.copy()
87
  if self.rephrase_question:
88
  new_inputs["question"] = new_question