""" Rerank with cross encoder. Ref: https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c https://github.com/langchain-ai/langchain/issues/13076 """ from __future__ import annotations from typing import Optional, Sequence from langchain.schema import Document from langchain.pydantic_v1 import Extra from langchain.callbacks.manager import Callbacks from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from sentence_transformers import CrossEncoder class BgeRerank(BaseDocumentCompressor): """ Re-rank with CrossEncoder. Ref: https://medium.aiplanet.com/advanced-rag-cohere-re-ranker-99acc941601c https://github.com/langchain-ai/langchain/issues/13076 good to read: https://zhuanlan.zhihu.com/p/676008717 or its source https://teemukanstren.com/2023/12/25/llmrag-based-question-answering/ """ # Note: switch to jina-turbo due to speed consideration # original was "BAAI/bge-reranker-large" model_name: str = "jinaai/jina-reranker-v1-turbo-en" """Model name to use for reranking.""" top_n: int = 6 """Number of documents to return.""" model: CrossEncoder = CrossEncoder(model_name, trust_remote_code=True) """CrossEncoder instance to use for reranking.""" def bge_rerank(self, query, docs): model_inputs = [[query, doc] for doc in docs] scores = self.model.predict(model_inputs) results = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) return results[: self.top_n] class Config: """Configuration for this pydantic object.""" extra = Extra.forbid arbitrary_types_allowed = True def compress_documents( self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """ Compress documents using BAAI/bge-reranker models. Args: documents: A sequence of documents to compress. query: The query to use for compressing the documents. callbacks: Callbacks to run during the compression process. Returns: A sequence of compressed documents. """ if len(documents) == 0: # to avoid empty api call return [] doc_list = list(documents) _docs = [d.page_content for d in doc_list] results = self.bge_rerank(query, _docs) final_results = [] for r in results: doc = doc_list[r[0]] doc.metadata["relevance_score"] = r[1] final_results.append(doc) return final_results