import logging
from typing import Dict, List, Optional

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch

from .embeddings import EmbeddingModel
from .reranker import Reranker


class RankingSystem:
    """Two-stage passage ranker: embedding retrieval followed by reranking.

    Stage one scores passages by the dot product between their embeddings
    and the query embedding. Stage two hands the best candidates to a
    ``Reranker`` and blends both scores into a ``final_score``.
    """

    def __init__(
        self,
        embedding_model: Optional[EmbeddingModel] = None,
        reranker: Optional[Reranker] = None,
    ):
        """Create a ranking system.

        Args:
            embedding_model: Encoder used for passages and queries. A default
                ``EmbeddingModel()`` is constructed when omitted.
            reranker: Second-stage scorer. A default ``Reranker()`` is
                constructed when omitted.
        """
        self.embedding_model = embedding_model or EmbeddingModel()
        self.reranker = reranker or Reranker()
        self.index = None  # FAISS index, populated by build_index()
        self.passages = None  # passages handed to the last build_index() call
        self.embedding_cache = {}  # passage text -> embedding vector

    def build_index(self, passages: List[Dict]):
        """Build a FAISS inner-product index over the given passages.

        Args:
            passages: Dicts that each carry the text under the ``'passage'``
                key.

        Returns:
            None. On success ``self.index`` holds a ``faiss.IndexFlatIP``
            containing one vector per passage. When there is nothing to
            encode, or the encoder output has no ``shape`` attribute, a
            message is logged and ``self.index`` is left untouched.
        """
        self.passages = passages
        texts = [p['passage'] for p in passages]
        if not texts:
            logging.warning("没有文本需要编码")
            return

        embeddings = self.embedding_model.encode(texts)
        if embeddings is None or not hasattr(embeddings, 'shape'):
            logging.error("编码结果为空或格式不正确")
            return

        # NOTE(review): IndexFlatIP ranks by raw inner product; this matches
        # cosine similarity only if the encoder emits normalized vectors.
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(embeddings.astype('float32'))

    def _cache_embeddings(self, texts: List[str]) -> None:
        """Encode every text not yet in ``embedding_cache`` in one batch."""
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # passage repeated in the input is only encoded once.
        missing = [
            text for text in dict.fromkeys(texts)
            if text not in self.embedding_cache
        ]
        if not missing:
            return
        encoded = self.embedding_model.encode(missing)
        for text, vector in zip(missing, encoded):
            self.embedding_cache[text] = vector

    def initial_ranking(
        self,
        query: str,
        passages: List[Dict],
        initial_top_k: int = 10,
    ) -> List[Dict]:
        """Rank passages by embedding similarity and return the top K.

        Args:
            query: Query string.
            passages: Candidate passages. Either dicts with a ``'passage'``
                key, or bare values which are wrapped as
                ``{'passage': value}`` (decided by inspecting the first
                element).
            initial_top_k: Maximum number of passages to return.

        Returns:
            Shallow copies of the best-scoring passages, highest score first,
            each extended with a float ``'retrieval_score'``. An empty input
            yields an empty list.
        """
        if not passages:
            # Nothing to rank; also avoids indexing passages[0] below.
            return []

        # Normalize bare passages into the dict form used downstream.
        if not isinstance(passages[0], dict):
            passages = [{'passage': p} for p in passages]

        texts = [p['passage'] for p in passages]
        # One batched encode call for all cache misses instead of one call
        # per passage.
        self._cache_embeddings(texts)
        embeddings = np.array([self.embedding_cache[text] for text in texts])

        # Score every passage against the query in a single matrix product.
        query_embedding = self.embedding_model.encode([query])[0]
        similarities = np.dot(embeddings, query_embedding)

        # argsort is ascending: reverse for best-first, then keep the top K.
        top_indices = np.argsort(similarities)[::-1][:initial_top_k]

        ranked_passages = []
        for idx in top_indices:
            passage = passages[idx].copy()  # leave the caller's dicts intact
            passage['retrieval_score'] = float(similarities[idx])
            ranked_passages.append(passage)
        return ranked_passages

    def rerank(
        self,
        query: str,
        initial_ranked: List[Dict],
        final_top_k: int = 3,
        retrieval_weight: float = 0.3,
        rerank_weight: float = 0.7,
    ) -> List[Dict]:
        """Rerank first-stage candidates and blend both scores.

        Args:
            query: Query string.
            initial_ranked: Candidates carrying a ``'retrieval_score'``.
            final_top_k: Maximum number of passages to return.
            retrieval_weight: Weight of ``'retrieval_score'`` in the blend.
            rerank_weight: Weight of ``'rerank_score'`` in the blend.

        Returns:
            The passages returned by the reranker, each extended with
            ``'final_score'``, sorted by that score descending and truncated
            to ``final_top_k``. An empty input yields an empty list without
            invoking the reranker.
        """
        if not initial_ranked:
            return []

        # The reranker is expected to attach 'rerank_score' to each passage;
        # a missing key raises KeyError below.
        reranked = self.reranker.rerank(query, initial_ranked)

        for passage in reranked:
            passage['final_score'] = (
                retrieval_weight * passage['retrieval_score']
                + rerank_weight * passage['rerank_score']
            )

        final_ranked = sorted(
            reranked,
            key=lambda item: item['final_score'],
            reverse=True,
        )
        return final_ranked[:final_top_k]

    def retrieve(
        self,
        query: str,
        passages: List[Dict],
        initial_top_k: int = 10,
        final_top_k: int = 3,
    ) -> List[Dict]:
        """Retrieve and rank passages for a query.

        Args:
            query: Query string.
            passages: Candidate passages to search.
            initial_top_k: Number of candidates kept after the embedding
                stage.
            final_top_k: Number of passages returned after reranking.

        Returns:
            List[Dict]: The final ranked passages, best first.
        """
        # 1. Embedding-based first pass.
        initial_results = self.initial_ranking(
            query, passages, initial_top_k=initial_top_k
        )
        # 2. Rerank the surviving candidates.
        final_results = self.rerank(
            query, initial_results, final_top_k=final_top_k
        )
        return final_results