import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker


class DocumentRetrieverWithReranker:
    """Retrieve documents with a base retriever, then rerank with a cross-encoder.

    The instance itself plays the role of the cross-encoder "model" handed to
    ``CrossEncoderReranker``. That compressor calls ``model.score(text_pairs)``,
    so this class must expose a ``score`` method — ``__call__`` alone is not
    enough and would raise ``AttributeError`` at query time. ``__call__`` is
    kept as an alias for backward compatibility with callers that invoked the
    instance directly.
    """

    def __init__(self, retriever, reranker_model_name="BAAI/bge-reranker-base", top_n=3):
        """Build the retrieval + reranking pipeline.

        Args:
            retriever: Any LangChain retriever used for the first-pass fetch.
            reranker_model_name: HuggingFace model id of the cross-encoder.
            top_n: Number of documents kept after reranking.
        """
        self.retriever = retriever
        self.reranker_model_name = reranker_model_name
        self.top_n = top_n
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(self.reranker_model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.reranker_model_name)
        self.model = self.model.to(self.device)
        # eval() disables dropout; torch.inference_mode() alone does not,
        # so without this the scores would be non-deterministic.
        self.model.eval()
        self.compressor = CrossEncoderReranker(model=self, top_n=self.top_n)
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.compressor,
            base_retriever=self.retriever,
        )

    def score(self, text_pairs):
        """Score (query, passage) pairs with the cross-encoder.

        This is the method ``CrossEncoderReranker`` invokes on its ``model``;
        the original class only defined ``__call__`` and therefore failed with
        ``AttributeError`` when the reranker ran.

        Args:
            text_pairs: List of ``[query, passage]`` string pairs.

        Returns:
            List of float relevance scores, one per pair.
        """
        with torch.inference_mode():
            inputs = self.tokenizer(
                text_pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs = inputs.to(self.device)
            scores = self.model(**inputs, return_dict=True).logits.view(-1).float()
        return scores.detach().cpu().tolist()

    # Backward-compatible alias: the original API invoked the instance directly.
    __call__ = score

    def retrieve_and_rerank(self, query):
        """Run first-pass retrieval followed by cross-encoder reranking."""
        return self.compression_retriever.invoke(query)

    @staticmethod
    def pretty_print_docs(docs):
        """Print each document's page_content, separated by a dashed rule."""
        print(
            f"\n{'-' * 100}\n".join(
                f"Document {i + 1}:\n\n{d.page_content}" for i, d in enumerate(docs)
            )
        )