# Spaces: Running
import logging
from typing import Dict, List, Optional

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from .embeddings import EmbeddingModel
from .reranker import Reranker
class RankingSystem:
    """Two-stage passage ranker: embedding retrieval followed by reranking.

    Stage one (`initial_ranking`) scores every passage against the query with
    an inner product of their embeddings. Stage two (`rerank`) hands the
    survivors to the reranker and blends both scores into ``final_score``.
    `retrieve` runs the two stages back to back.

    `build_index` additionally maintains a FAISS inner-product index over a
    passage collection; no other method of this class reads that index.
    """

    # Annotations are quoted so they are not evaluated when the class is
    # defined; the names are only needed when a default instance is built.
    def __init__(self,
                 embedding_model: Optional["EmbeddingModel"] = None,
                 reranker: Optional["Reranker"] = None):
        """Store the collaborators, building default ones when omitted.

        Args:
            embedding_model: Object exposing ``encode(list_of_texts)``.
                A default ``EmbeddingModel()`` is created when falsy.
            reranker: Object exposing ``rerank(query, passages)``.
                A default ``Reranker()`` is created when falsy.
        """
        self.embedding_model = embedding_model or EmbeddingModel()
        self.reranker = reranker or Reranker()
        # FAISS index and the passages it was built from (see build_index).
        self.index = None
        self.passages = None
        # Passage text -> embedding. NOTE: entries are never evicted, so the
        # cache grows with every distinct text that is ranked.
        self.embedding_cache = {}

    def build_index(self, passages: List[Dict]):
        """Build a FAISS inner-product index over ``passages``.

        ``self.index`` and ``self.passages`` are always updated together so
        they never describe different collections:

        * empty input clears the index and stores the (empty) passages;
        * an unusable encoding result leaves the previous state untouched.

        Args:
            passages: Dicts that each carry the text under the ``'passage'``
                key.
        """
        texts = [p['passage'] for p in passages]
        if not texts:
            logging.warning("没有文本需要编码")
            self.passages = passages
            self.index = None
            return

        embeddings = self.embedding_model.encode(texts)
        if embeddings is None or not hasattr(embeddings, 'shape'):
            logging.error("编码结果为空或格式不正确")
            return

        # FAISS expects a C-contiguous float32 matrix of shape (n, dim).
        matrix = np.ascontiguousarray(embeddings, dtype='float32')
        if matrix.ndim != 2:
            logging.error("编码结果为空或格式不正确")
            return

        index = faiss.IndexFlatIP(matrix.shape[1])
        index.add(matrix)
        self.index = index
        self.passages = passages

    def initial_ranking(self, query: str, passages: List[Dict],
                        initial_top_k: int = 10) -> List[Dict]:
        """Rank ``passages`` against ``query`` and return the best ones.

        Scores are raw inner products of the embeddings (equal to cosine
        similarity only when the embedding model returns unit-length
        vectors).

        Args:
            query: Query string.
            passages: Dicts with a ``'passage'`` key. If the first element is
                not a dict, every element is wrapped as ``{'passage': item}``.
            initial_top_k: Maximum number of passages to return.

        Returns:
            Shallow copies of the selected passages, best first, each with an
            added ``'retrieval_score'`` float. The input dicts are not
            modified. An empty input yields an empty list.
        """
        if not passages:
            return []

        # Accept bare texts as well as passage dicts.
        if not isinstance(passages[0], dict):
            passages = [{'passage': p} for p in passages]

        texts = [p['passage'] for p in passages]

        # Encode every uncached text in a single batch; dict.fromkeys
        # de-duplicates while keeping first-seen order.
        missing = [t for t in dict.fromkeys(texts)
                   if t not in self.embedding_cache]
        if missing:
            encoded = self.embedding_model.encode(missing)
            for text, vector in zip(missing, encoded):
                self.embedding_cache[text] = vector

        embeddings = np.array([self.embedding_cache[t] for t in texts])

        # One matrix-vector product scores all passages at once.
        query_embedding = self.embedding_model.encode([query])[0]
        similarities = np.dot(embeddings, query_embedding)

        # argsort is ascending, so reverse before taking the top k.
        indices = np.argsort(similarities)[::-1][:initial_top_k]

        ranked_passages = []
        for idx in indices:
            passage = passages[idx].copy()
            passage['retrieval_score'] = float(similarities[idx])
            ranked_passages.append(passage)
        return ranked_passages

    def rerank(self, query: str, initial_ranked: List[Dict],
               final_top_k: int = 3,
               retrieval_weight: float = 0.3,
               rerank_weight: float = 0.7) -> List[Dict]:
        """Rerank first-stage results and blend the two scores.

        Args:
            query: Query string.
            initial_ranked: Passages carrying ``'retrieval_score'``, e.g. the
                output of `initial_ranking`.
            final_top_k: Maximum number of passages to return.
            retrieval_weight: Weight of ``'retrieval_score'`` in the blend.
            rerank_weight: Weight of ``'rerank_score'`` in the blend.

        Returns:
            The reranker's passages sorted by descending ``'final_score'``
            (``retrieval_weight * retrieval_score + rerank_weight *
            rerank_score``), truncated to ``final_top_k``. ``'final_score'``
            is written onto the dicts returned by the reranker.

        Raises:
            KeyError: If a reranked passage lacks ``'retrieval_score'`` or
                ``'rerank_score'``.
        """
        # Nothing to score: skip the reranker entirely.
        if not initial_ranked:
            return []

        reranked = self.reranker.rerank(query, initial_ranked)

        for passage in reranked:
            passage['final_score'] = (
                retrieval_weight * passage['retrieval_score'] +
                rerank_weight * passage['rerank_score']
            )

        final_ranked = sorted(
            reranked,
            key=lambda x: x['final_score'],
            reverse=True,
        )
        return final_ranked[:final_top_k]

    def retrieve(self, query: str, passages: List[Dict],
                 initial_top_k: int = 10,
                 final_top_k: int = 3) -> List[Dict]:
        """Retrieve and rank passages for ``query``.

        Args:
            query: Query string.
            passages: Candidate passages to search.
            initial_top_k: How many first-stage candidates go to the reranker.
            final_top_k: How many passages are returned after reranking.

        Returns:
            The reranked passages, best first; empty when ``passages`` is
            empty.
        """
        # 1. First-stage ranking by embedding similarity.
        initial_results = self.initial_ranking(
            query, passages, initial_top_k=initial_top_k
        )
        # 2. Second-stage reranking of the survivors.
        return self.rerank(query, initial_results, final_top_k=final_top_k)