Spaces:
Running
Running
from typing import List, Dict | |
from FlagEmbedding import FlagReranker | |
import logging | |
import torch | |
import os | |
from sentence_transformers import CrossEncoder | |
class Reranker:
    """Cross-encoder reranker backed by a FlagEmbedding ``FlagReranker``.

    The model is loaded once in ``__init__``; :meth:`rerank` scores
    ``[query, passage]`` pairs with it and returns the passages best-first.
    """

    def __init__(self, model_path="BAAI/bge-reranker-large", device=None, use_fp16=None):
        """Load the reranking model.

        Args:
            model_path: Model name or local path passed to ``FlagReranker``.
            device: Device passed to ``FlagReranker``. ``None`` selects
                ``"cuda"`` when ``torch.cuda.is_available()`` and ``"cpu"``
                otherwise.
            use_fp16: Whether to load the model in half precision. ``None``
                enables it on any device other than ``"cpu"``.

        Raises:
            Exception: Any error raised by ``FlagReranker`` while loading is
                logged and re-raised unchanged.
        """
        if device is None:
            # Probe CUDA once so the load call and the log message always agree.
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if use_fp16 is None:
            # Half precision is generally not beneficial on CPU; only default to it elsewhere.
            use_fp16 = device != "cpu"
        try:
            self.model = FlagReranker(
                model_path,
                use_fp16=use_fp16,
                device=device,
            )
            logging.info(f"成功加载重排序模型 {model_path} 到 {device} 设备")
        except Exception as e:
            logging.error(f"加载重排序模型失败: {str(e)}")
            raise

    @staticmethod
    def _to_score_list(raw_scores) -> List[float]:
        """Normalize ``compute_score`` output to a list of Python floats.

        ``FlagReranker.compute_score`` may return a bare scalar rather than a
        list when given a single pair; iterating that directly would raise.
        """
        try:
            items = iter(raw_scores)
        except TypeError:
            # Scalar result (e.g. a float or 0-d array/tensor): one score.
            return [float(raw_scores)]
        return [float(s) for s in items]

    def rerank(self, query: str, passages: List[Dict]) -> List[Dict]:
        """Score each passage against ``query`` and return them best-first.

        Each dict in ``passages`` must contain a ``'passage'`` text entry.
        Its score is written into that same dict under ``'rerank_score'``, so
        the input dicts are mutated in place. A new list sorted by that score
        in descending order is returned; ties keep their input order.

        If anything goes wrong, the error is logged with its traceback and
        ``passages`` is returned as-is, in its original order.
        """
        if not passages:
            # Nothing to score; avoid calling the model with an empty batch.
            return passages
        try:
            pairs = [[query, p['passage']] for p in passages]
            scores = self._to_score_list(self.model.compute_score(pairs))
            # Validate before mutating so a bad model result leaves inputs untouched.
            if len(scores) != len(passages):
                raise ValueError(
                    f"expected {len(passages)} scores, got {len(scores)}"
                )
            for passage, score in zip(passages, scores):
                passage['rerank_score'] = score
            return sorted(passages, key=lambda x: x['rerank_score'], reverse=True)
        except Exception as e:
            logging.exception(f"重排序过程中出错: {str(e)}")
            # Best-effort fallback: keep the caller's original ordering.
            return passages