File size: 1,422 Bytes
a361ca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer
import faiss
import json


class CustomRetriever:
    """Two-stage retriever: FAISS vector search followed by cross-encoder reranking.

    The bi-encoder embeds the query, the FAISS index proposes up to ``top_k``
    candidate chunks, and the cross-encoder reorders those candidates. The
    best ones are rendered into a single context string together with their
    source metadata.
    """

    def __init__(self, chunks_path, embeddings_path, metadata_path, top_k: int = 50):
        """Load both models, the chunk texts, the FAISS index and the metadata.

        Args:
            chunks_path: Path to a JSON file holding the chunk texts. It is
                indexed by integer position in ``retrieve``, so it is expected
                to decode to a list.
            embeddings_path: Path to a serialized FAISS index.
            metadata_path: Path to a JSON object keyed by the chunk position
                as a string; each value must provide ``book``,
                ``article_num`` and ``link``.
            top_k: Number of candidates requested from the FAISS index.
        """
        self.model_bi = SentenceTransformer("deepvk/USER-bge-m3")
        self.model_cross = CrossEncoder("DiTy/cross-encoder-russian-msmarco")
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(chunks_path, "r", encoding="utf-8") as f:
            self.chunks = json.load(f)
        self.index = faiss.read_index(embeddings_path)
        self.top_k = top_k
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)

    def retrieve(self, query: str, top_n: int = 5) -> str:
        """Return a formatted context string with the best facts for ``query``.

        Args:
            query: The user question.
            top_n: Maximum number of reranked facts to include (default 5).

        Returns:
            The concatenated fact entries, each followed by its source
            metadata. Fewer than ``top_n`` entries are returned when fewer
            candidates are available; an empty string is returned when there
            are none or when ``top_n`` is not positive.

        Raises:
            KeyError: If a retrieved chunk has no metadata entry or the entry
                lacks one of the expected fields.
        """
        if top_n <= 0:
            return ""

        query_vector = self.model_bi.encode([query])
        # Normalizes the query in place.
        # NOTE(review): presumably the index stores normalized vectors too,
        # which would make the search cosine-based; confirm.
        faiss.normalize_L2(query_vector)
        _, indices = self.index.search(query_vector, self.top_k)

        # FAISS pads the result with -1 when the index holds fewer than
        # top_k vectors; a -1 would otherwise silently select the last chunk.
        candidate_ids = [int(idx) for idx in indices[0] if idx >= 0]
        if not candidate_ids:
            return ""

        candidates = [self.chunks[idx] for idx in candidate_ids]
        ranked = self.model_cross.rank(query, candidates)

        facts = []
        # Slicing instead of a fixed range avoids IndexError on short results.
        for position, hit in enumerate(ranked[:top_n], start=1):
            corpus_id = hit["corpus_id"]  # position within `candidates`
            meta = self.metadata[str(candidate_ids[corpus_id])]
            facts.append(
                f"Факт {position}: {candidates[corpus_id]}. Источник:\n"
                f"книга - {meta['book']}\n"
                f"номер статьи - {meta['article_num']}\n"
                f"ссылка на книгу - {meta['link']}\n"
            )
        return "".join(facts)