|
from sentence_transformers import CrossEncoder
|
|
from sentence_transformers import SentenceTransformer
|
|
import faiss
|
|
import json
|
|
|
|
|
|
class CustomRetriever:
    """Two-stage retriever: FAISS bi-encoder search, then cross-encoder reranking.

    Text chunks, a FAISS index and per-chunk metadata are loaded once at
    construction. ``retrieve`` returns the best reranked chunks rendered as a
    single context string with source attribution (book, article number, link).
    """

    def __init__(
        self,
        chunks_path,
        embeddings_path,
        metadata_path,
        top_k=50,
        bi_encoder_name="deepvk/USER-bge-m3",
        cross_encoder_name="DiTy/cross-encoder-russian-msmarco",
    ):
        """Load the models, text chunks, FAISS index and metadata.

        Args:
            chunks_path: JSON file with the chunk texts, indexed by the integer
                ids the FAISS index returns (presumably a JSON list — confirm).
            embeddings_path: path of a serialized FAISS index, read with
                ``faiss.read_index``.
            metadata_path: JSON object keyed by ``str(chunk_id)``; each value
                must provide ``book``, ``article_num`` and ``link``.
            top_k: number of candidates fetched from FAISS before reranking.
            bi_encoder_name: SentenceTransformer model used to embed queries.
            cross_encoder_name: CrossEncoder model used to rerank candidates.
        """
        self.model_bi = SentenceTransformer(bi_encoder_name)
        self.model_cross = CrossEncoder(cross_encoder_name)
        # Explicit UTF-8: the data contains Russian text, and the platform
        # default encoding (e.g. cp1251 on Windows) could mis-decode it.
        with open(chunks_path, "r", encoding="utf-8") as f:
            self.chunks = json.load(f)
        self.index = faiss.read_index(embeddings_path)
        self.top_k = top_k
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

    def retrieve(self, query, top_n=5):
        """Return a formatted context built from the best chunks for ``query``.

        The query is embedded with the bi-encoder, L2-normalized and searched
        in the FAISS index for ``self.top_k`` candidates. The candidates are
        then reranked with the cross-encoder, and at most ``top_n`` of the best
        are rendered as numbered "Факт" entries with their source metadata.

        Args:
            query: user query text.
            top_n: maximum number of facts to include (default 5, the
                previously hard-coded value).

        Returns:
            The concatenated context string, or ``""`` when the index yields
            no candidates.
        """
        query_vector = self.model_bi.encode([query])
        # Normalized in place; presumably the index stores normalized vectors
        # so that inner-product search acts as cosine similarity — confirm.
        faiss.normalize_L2(query_vector)
        _, indices = self.index.search(query_vector, self.top_k)

        # FAISS pads results with -1 when fewer than top_k hits exist; those
        # ids must be dropped, otherwise chunks[-1] silently returns the last
        # chunk and the metadata lookup for "-1" fails.
        chunk_ids = [int(idx) for idx in indices[0] if idx != -1]
        if not chunk_ids:
            return ""
        candidates = [self.chunks[chunk_id] for chunk_id in chunk_ids]

        # rank() returns dicts sorted by descending score; "corpus_id" is the
        # position of the document within `candidates` (and so `chunk_ids`).
        ranked = self.model_cross.rank(query, candidates)

        facts = []
        # Slicing clamps to the available hits instead of raising IndexError.
        for number, hit in enumerate(ranked[:top_n], start=1):
            position = hit["corpus_id"]
            meta = self.metadata[str(chunk_ids[position])]
            facts.append(
                f"Факт {number}: {candidates[position]}. Источник:\n"
                f"книга - {meta['book']}\n"
                f"номер статьи - {meta['article_num']}\n"
                f"ссылка на книгу - {meta['link']}\n"
            )
        return "".join(facts)
|
|
|