from typing import List, Dict
from FlagEmbedding import FlagReranker
import logging
import torch
import os
from sentence_transformers import CrossEncoder


class Reranker:
    """Rerank retrieved passages with FlagEmbedding's ``FlagReranker``."""

    def __init__(self, model_path: str = "BAAI/bge-reranker-large"):
        """Load the reranker model.

        Args:
            model_path: Model id or local path handed to ``FlagReranker``.

        Raises:
            Exception: Whatever ``FlagReranker`` raised while loading; it is
                logged (with traceback) and then re-raised unchanged.
        """
        try:
            # Decide the device once so the constructor and the log agree.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = FlagReranker(
                model_path,
                use_fp16=True,
                device=device,
            )
            logging.info("成功加载重排序模型 %s 到 %s 设备", model_path, device)
        except Exception as e:
            logging.exception("加载重排序模型失败: %s", e)
            raise

    @staticmethod
    def _to_score_list(scores) -> List[float]:
        """Normalize ``compute_score`` output to a list of Python floats.

        ``FlagReranker.compute_score`` returns a bare scalar instead of a
        list when it is given a single pair, so a non-iterable result is
        wrapped into a one-element list.
        """
        try:
            iterator = iter(scores)
        except TypeError:
            # Scalar result (single pair): float, 0-d array or 0-d tensor.
            return [float(scores)]
        return [float(s) for s in iterator]

    def rerank(self, query: str, passages: List[Dict]) -> List[Dict]:
        """Reorder passages by reranker relevance to ``query``.

        Each dict in ``passages`` is expected to carry its text under the
        ``'passage'`` key.

        On success, every dict is updated in place with a float
        ``'rerank_score'``. A new list sorted by that score (highest first)
        is returned.

        This is best-effort. If scoring fails for any reason, the error is
        logged and ``passages`` is returned as-is, in its original order.
        No dict is left partially updated.

        Args:
            query: The query text to score passages against.
            passages: Candidate passage dicts, each with a ``'passage'`` key.

        Returns:
            The passages sorted by descending ``'rerank_score'``, or the
            original ``passages`` list if reranking failed or it was empty.
        """
        if not passages:
            # Nothing to score; don't invoke the model on an empty batch.
            return passages

        try:
            pairs = [[query, p['passage']] for p in passages]
            scores = self._to_score_list(self.model.compute_score(pairs))

            # zip() would silently truncate on a mismatch, leaving some
            # passages without a score; treat that as a failure instead.
            if len(scores) != len(passages):
                raise ValueError(
                    f"expected {len(passages)} scores, got {len(scores)}"
                )

            # All scores are validated and converted before any dict is
            # touched, so a failure above leaves the inputs unmodified.
            for passage, score in zip(passages, scores):
                passage['rerank_score'] = score

            return sorted(
                passages, key=lambda p: p['rerank_score'], reverse=True
            )
        except Exception as e:
            # Deliberate best-effort boundary: fall back to the original order.
            logging.exception("重排序过程中出错: %s", e)
            return passages