QA_lawyer / src /retriever.py
I77's picture
Upload 15 files
a361ca0 verified
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer
import faiss
import json
class CustomRetriever:
    """Two-stage retriever: FAISS bi-encoder search, then cross-encoder re-ranking.

    Candidates are fetched from a FAISS index using query embeddings from the
    ``deepvk/USER-bge-m3`` bi-encoder, re-scored with the
    ``DiTy/cross-encoder-russian-msmarco`` cross-encoder, and the best hits
    are formatted into one context string together with their source metadata.
    """

    def __init__(self, chunks_path, embeddings_path, metadata_path, top_k=50):
        """Load both models, the chunk texts, the FAISS index and the metadata.

        Args:
            chunks_path: Path to a JSON file holding the chunk texts, indexed
                by the integer ids the FAISS index returns (presumably a list
                aligned with the index rows -- TODO confirm).
            embeddings_path: Path to a serialized FAISS index, loaded with
                ``faiss.read_index``.
            metadata_path: Path to a JSON object mapping the stringified chunk
                id to a dict with ``book``, ``article_num`` and ``link`` keys.
            top_k: Number of candidates fetched from FAISS before re-ranking.
        """
        self.model_bi = SentenceTransformer("deepvk/USER-bge-m3")
        self.model_cross = CrossEncoder("DiTy/cross-encoder-russian-msmarco")
        # JSON is UTF-8 by spec; relying on the platform default encoding
        # (e.g. cp1251 on Windows) can mis-decode non-ASCII text.
        with open(chunks_path, "r", encoding="utf-8") as f:
            self.chunks = json.load(f)
        self.index = faiss.read_index(embeddings_path)
        self.top_k = top_k
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

    def retrieve(self, query, top_n=5):
        """Return a formatted context string built from the best-matching chunks.

        Args:
            query: Free-text user question.
            top_n: How many re-ranked chunks to include (default 5, the value
                that used to be hard-coded). Fewer are returned when fewer
                candidates are available.

        Returns:
            Numbered facts ("Факт 1: ...") each followed by its source book,
            article number and link; an empty string if nothing was found.
        """
        query_vector = self.model_bi.encode([query])
        # faiss.normalize_L2 L2-normalizes in place; presumably the index was
        # built from normalized vectors as well -- TODO confirm.
        faiss.normalize_L2(query_vector)
        _, indices = self.index.search(query_vector, self.top_k)

        # FAISS pads results with -1 when the index holds fewer than top_k
        # vectors; drop them so chunks[-1] is never picked up by accident.
        candidate_ids = [int(idx) for idx in indices[0] if idx >= 0]
        if not candidate_ids:
            return ""
        candidates = [self.chunks[idx] for idx in candidate_ids]

        # CrossEncoder.rank returns hits sorted by descending score; each
        # hit's "corpus_id" is a position in `candidates`.
        ranked = self.model_cross.rank(query, candidates)

        facts = []
        for position, hit in enumerate(ranked[:top_n], start=1):
            corpus_id = hit["corpus_id"]
            meta = self.metadata[str(candidate_ids[corpus_id])]
            facts.append(
                f"Факт {position}: {candidates[corpus_id]}. Источник:\n"
                f"книга - {meta['book']}\n"
                f"номер статьи - {meta['article_num']}\n"
                f"ссылка на книгу - {meta['link']}\n"
            )
        return "".join(facts)