File size: 3,936 Bytes
7cc8bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import faiss
import numpy as np
from typing import List, Dict
from .embeddings import EmbeddingModel
from .reranker import Reranker
from sentence_transformers import SentenceTransformer
import logging
from sklearn.metrics.pairwise import cosine_similarity
import torch

class RankingSystem:
    """Two-stage passage ranker: embedding retrieval followed by reranking.

    Stage one scores every passage against the query with a dot product of
    their embeddings (``initial_ranking``). Stage two hands the survivors to
    a reranker and blends both scores into a ``final_score`` (``rerank``).
    ``retrieve`` chains the two stages.

    Attributes:
        embedding_model: Object exposing ``encode(list_of_texts)``.
        reranker: Object exposing ``rerank(query, passages)``.
        index: FAISS index created by ``build_index``; ``None`` until a
            build succeeds.
        passages: Passage list most recently given to ``build_index``.
        embedding_cache: Maps passage text to its embedding so repeated
            texts are encoded only once per instance. The cache is
            unbounded.
    """

    def __init__(self,
                 embedding_model: "EmbeddingModel" = None,
                 reranker: "Reranker" = None):
        """Store the collaborators, creating defaults for any left as None.

        Args:
            embedding_model: Encoder to use; a fresh ``EmbeddingModel()`` is
                created when omitted.
            reranker: Reranker to use; a fresh ``Reranker()`` is created
                when omitted.
        """
        # Compare against None explicitly: `x or Default()` would silently
        # discard a caller-supplied object that happens to be falsy
        # (for example one that defines __len__ or __bool__).
        if embedding_model is None:
            embedding_model = EmbeddingModel()
        if reranker is None:
            reranker = Reranker()
        self.embedding_model = embedding_model
        self.reranker = reranker
        self.index = None
        self.passages = None
        self.embedding_cache = {}

    def build_index(self, passages: List[Dict]):
        """Build a FAISS inner-product index over the passages' texts.

        On success ``self.index`` holds the new index. When there is nothing
        to encode, or the encoder output is not a 2-D array, a message is
        logged and ``self.index`` is left as ``None``.

        Args:
            passages: Dicts that each carry their text under ``'passage'``.
        """
        self.passages = passages
        # Drop any previous index first so an early return below cannot
        # leave a stale index paired with the new passage list.
        self.index = None
        texts = [p['passage'] for p in passages]

        if not texts:
            logging.warning("没有文本需要编码")
            return

        embeddings = self.embedding_model.encode(texts)

        # The index needs (n_texts, dimension); anything else would crash
        # on shape[1] below, so reject it together with None / shapeless.
        if (embeddings is None
                or not hasattr(embeddings, 'shape')
                or len(embeddings.shape) != 2):
            logging.error("编码结果为空或格式不正确")
            return

        vectors = np.ascontiguousarray(embeddings, dtype='float32')
        index = faiss.IndexFlatIP(vectors.shape[1])
        index.add(vectors)
        # Publish only once the index is fully populated.
        self.index = index

    def initial_ranking(self, query: str, passages: List[Dict], initial_top_k: int = 10) -> List[Dict]:
        """Rank passages by embedding dot product with the query.

        Args:
            query: Query string.
            passages: Passages to rank. Each item is either a dict holding
                its text under ``'passage'`` or a bare text, which is
                wrapped as ``{'passage': text}``.
            initial_top_k: Maximum number of results to return.

        Returns:
            Shallow copies of the best-scoring passages, highest score
            first, each with a float ``'retrieval_score'`` added. The input
            dicts are not modified. An empty input yields an empty list.
        """
        if not passages:
            return []

        # Normalise per item so mixed str/dict input is handled too.
        passages = [p if isinstance(p, dict) else {'passage': p} for p in passages]
        texts = [p['passage'] for p in passages]

        # Encode all uncached texts in one batch. dict.fromkeys removes
        # duplicates while keeping first-seen order.
        cache = self.embedding_cache
        missing = [text for text in dict.fromkeys(texts) if text not in cache]
        if missing:
            for text, vector in zip(missing, self.embedding_model.encode(missing)):
                cache[text] = vector

        embeddings = np.array([cache[text] for text in texts])

        # One matrix-vector product scores every passage at once.
        query_embedding = self.embedding_model.encode([query])[0]
        similarities = np.dot(embeddings, query_embedding)

        # Indices of the best scores, descending.
        indices = np.argsort(similarities)[::-1][:initial_top_k]

        ranked_passages = []
        for idx in indices:
            passage = passages[idx].copy()
            passage['retrieval_score'] = float(similarities[idx])
            ranked_passages.append(passage)

        return ranked_passages

    def rerank(self, query: str, initial_ranked: List[Dict], final_top_k: int = 3,
               *, retrieval_weight: float = 0.3, rerank_weight: float = 0.7) -> List[Dict]:
        """Rerank first-stage results and blend both scores.

        Each passage returned by the reranker receives
        ``final_score = retrieval_weight * retrieval_score
        + rerank_weight * rerank_score``; the dicts returned by the reranker
        are updated in place.

        Args:
            query: Query string.
            initial_ranked: Passages carrying ``'retrieval_score'``.
            final_top_k: Maximum number of results to return.
            retrieval_weight: Weight of the first-stage score.
            rerank_weight: Weight of the reranker score.

        Returns:
            Up to ``final_top_k`` passages, highest ``final_score`` first.
            An empty input yields an empty list without calling the
            reranker.

        Raises:
            KeyError: If a reranked passage lacks ``'retrieval_score'`` or
                ``'rerank_score'``.
        """
        if not initial_ranked:
            return []

        reranked = self.reranker.rerank(query, initial_ranked)

        for passage in reranked:
            passage['final_score'] = (
                retrieval_weight * passage['retrieval_score'] +
                rerank_weight * passage['rerank_score']
            )

        final_ranked = sorted(
            reranked,
            key=lambda x: x['final_score'],
            reverse=True
        )

        return final_ranked[:final_top_k]

    def retrieve(self, query: str, passages: List[Dict],
                 initial_top_k: int = 10, final_top_k: int = 3) -> List[Dict]:
        """Retrieve and rank passages for a query.

        Args:
            query: Query string.
            passages: Candidate passages to search.
            initial_top_k: Number of candidates kept after the embedding
                stage.
            final_top_k: Number of results returned after reranking.

        Returns:
            List[Dict]: Ranked passages, best first.
        """
        # 1. Embedding-based first pass.
        initial_results = self.initial_ranking(query, passages, initial_top_k=initial_top_k)

        # 2. Rerank the survivors.
        final_results = self.rerank(query, initial_results, final_top_k=final_top_k)

        return final_results