Toursim-Test / src /core /ranking.py
zhuhai111's picture
Upload 43 files
7cc8bc0 verified
import faiss
import numpy as np
from typing import List, Dict
from .embeddings import EmbeddingModel
from .reranker import Reranker
from sentence_transformers import SentenceTransformer
import logging
from sklearn.metrics.pairwise import cosine_similarity
import torch
class RankingSystem:
    """Two-stage passage ranker: embedding retrieval followed by reranking.

    Stage one scores passages by the dot product between their embeddings
    and the query embedding. Stage two hands the best candidates to a
    reranker and blends both scores into a ``final_score``.
    """

    def __init__(self,
                 embedding_model: "EmbeddingModel" = None,
                 reranker: "Reranker" = None):
        """Create a ranking system.

        Args:
            embedding_model: Object exposing ``encode(list_of_texts)``. A
                default ``EmbeddingModel()`` is created when omitted.
            reranker: Object exposing ``rerank(query, passages)``. A default
                ``Reranker()`` is created when omitted.
        """
        # Compare against None explicitly so that a caller-supplied object
        # that happens to be falsy is not silently replaced by a default.
        if embedding_model is None:
            embedding_model = EmbeddingModel()
        if reranker is None:
            reranker = Reranker()
        self.embedding_model = embedding_model
        self.reranker = reranker
        self.index = None
        self.passages = None
        # Maps passage text -> embedding. NOTE(review): this cache is never
        # evicted, so it grows with every distinct passage text seen.
        self.embedding_cache = {}

    def build_index(self, passages: List[Dict]):
        """Build a FAISS inner-product index over ``passages``.

        Args:
            passages: Dicts carrying the text to index under ``'passage'``.

        On success ``self.index`` holds the new index. When there is nothing
        to encode, or the encoder output is unusable, a message is logged and
        ``self.index`` is left as ``None`` so that a previously built index
        can never be paired with the new ``self.passages``.
        """
        self.passages = passages
        # Drop any earlier index up front; it no longer matches self.passages.
        self.index = None

        texts = [p['passage'] for p in passages]
        if not texts:
            logging.warning("没有文本需要编码")
            return

        embeddings = self.embedding_model.encode(texts)
        # The index needs a 2-D (n_texts, dimension) array-like.
        if (embeddings is None
                or not hasattr(embeddings, 'shape')
                or len(embeddings.shape) != 2):
            logging.error("编码结果为空或格式不正确")
            return

        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(embeddings.astype('float32'))
        self.index = index

    def initial_ranking(self, query: str, passages: List[Dict], initial_top_k: int = 10) -> List[Dict]:
        """Rank ``passages`` by embedding similarity and keep the best K.

        Args:
            query: Query text.
            passages: Dicts with a ``'passage'`` key; bare items are wrapped
                as ``{'passage': item}``.
            initial_top_k: Maximum number of passages to return.

        Returns:
            Shallow copies of the best passages, highest score first, each
            with a float ``'retrieval_score'`` (dot product with the query
            embedding). An empty input yields an empty list.
        """
        if passages is None or len(passages) == 0:
            return []

        # Normalise element by element so mixed str/dict input also works.
        normalized = [
            p if isinstance(p, dict) else {'passage': p}
            for p in passages
        ]
        texts = [p['passage'] for p in normalized]

        # Encode every cache miss in a single batch (deduplicated, order kept)
        # instead of issuing one encoder call per passage.
        cache = self.embedding_cache
        missing = list(dict.fromkeys(t for t in texts if t not in cache))
        if missing:
            for text, vector in zip(missing, self.embedding_model.encode(missing)):
                cache[text] = vector

        embeddings = np.array([cache[text] for text in texts])
        query_embedding = self.embedding_model.encode([query])[0]
        similarities = np.dot(embeddings, query_embedding)

        # Descending order of similarity, truncated to the requested size.
        best = np.argsort(similarities)[::-1][:initial_top_k]

        ranked_passages = []
        for idx in best:
            passage = normalized[idx].copy()
            passage['retrieval_score'] = float(similarities[idx])
            ranked_passages.append(passage)
        return ranked_passages

    def rerank(self, query: str, initial_ranked: List[Dict], final_top_k: int = 3,
               retrieval_weight: float = 0.3, rerank_weight: float = 0.7) -> List[Dict]:
        """Rerank first-stage candidates and blend both scores.

        Args:
            query: Query text.
            initial_ranked: Candidates carrying ``'retrieval_score'``.
            final_top_k: Maximum number of passages to return.
            retrieval_weight: Weight of ``'retrieval_score'`` in the blend.
            rerank_weight: Weight of ``'rerank_score'`` in the blend.

        Returns:
            The reranker's output sorted by descending ``'final_score'`` and
            truncated to ``final_top_k``. The reranker is expected to attach
            ``'rerank_score'`` to every passage it returns.
        """
        # Nothing to score; avoid calling the reranker with an empty batch.
        if not initial_ranked:
            return []

        reranked = self.reranker.rerank(query, initial_ranked)

        for passage in reranked:
            passage['final_score'] = (
                retrieval_weight * passage['retrieval_score'] +
                rerank_weight * passage['rerank_score']
            )

        final_ranked = sorted(
            reranked,
            key=lambda x: x['final_score'],
            reverse=True
        )
        return final_ranked[:final_top_k]

    def retrieve(self, query: str, passages: List[Dict],
                 initial_top_k: int = 10, final_top_k: int = 3) -> List[Dict]:
        """Retrieve and rank passages for ``query``.

        Args:
            query: Query text.
            passages: Candidate passages to search.
            initial_top_k: Number of candidates kept after the embedding stage.
            final_top_k: Number of passages returned after reranking.

        Returns:
            List[Dict]: The final ranked passages, best first.
        """
        # 1. Embedding-based first pass.
        initial_results = self.initial_ranking(query, passages, initial_top_k=initial_top_k)
        # 2. Rerank the surviving candidates.
        return self.rerank(query, initial_results, final_top_k=final_top_k)