Toursim-Test / src /core /reranker.py
zhuhai111's picture
Upload 43 files
7cc8bc0 verified
from typing import List, Dict
from FlagEmbedding import FlagReranker
import logging
import torch
import os
from sentence_transformers import CrossEncoder
class Reranker:
    """Re-rank retrieved passages against a query using a ``FlagReranker`` model."""

    def __init__(self, model_path="BAAI/bge-reranker-large"):
        """Load the re-ranking model onto CUDA when available, otherwise CPU.

        Args:
            model_path: Model name or local path handed to ``FlagReranker``.

        Raises:
            Exception: Any error raised while loading the model is logged
                and then re-raised unchanged.
        """
        try:
            # Resolve the device once so the model and the log line agree.
            use_cuda = torch.cuda.is_available()
            device = "cuda" if use_cuda else "cpu"
            self.model = FlagReranker(
                model_path,
                # Half precision is only requested when running on a GPU.
                use_fp16=use_cuda,
                device=device,
            )
            logging.info(f"成功加载重排序模型 {model_path} 到 {device} 设备")
        except Exception as e:
            logging.error(f"加载重排序模型失败: {str(e)}")
            raise

    @staticmethod
    def _as_score_list(raw_scores):
        """Return ``raw_scores`` as a list, wrapping a lone scalar score."""
        try:
            return list(raw_scores)
        except TypeError:
            # Not iterable: a single score was returned instead of a sequence.
            return [raw_scores]

    def rerank(self, query: str, passages: List[Dict]) -> List[Dict]:
        """Score each passage against ``query`` and sort by descending score.

        Each dict in ``passages`` must carry its text under the ``'passage'``
        key. The score is written back into the same dict under
        ``'rerank_score'`` (the input dicts are mutated in place).

        Args:
            query: The query text.
            passages: Candidate passages to re-rank.

        Returns:
            A new list of the same dicts ordered from highest to lowest
            ``'rerank_score'``. An empty list is returned for empty input.
            If scoring fails, the error is logged and ``passages`` is
            returned in its original order.
        """
        if not passages:
            # Nothing to score; skip the model call entirely.
            return []
        try:
            pairs = [[query, p['passage']] for p in passages]
            # NOTE(review): some FlagEmbedding releases appear to return a bare
            # scalar when given exactly one pair - confirm against the installed
            # version. Normalising here keeps the single-passage case working.
            scores = self._as_score_list(self.model.compute_score(pairs))
            for passage, score in zip(passages, scores):
                passage['rerank_score'] = float(score)
            return sorted(passages, key=lambda x: x['rerank_score'], reverse=True)
        except Exception as e:
            logging.error(f"重排序过程中出错: {str(e)}")
            # Best-effort fallback: keep the caller's original ordering.
            return passages