import logging
import os
from typing import Dict, List

import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

from .base import BaseRetriever

class MemoRAG(BaseRetriever):
    def __init__(self, config: Dict):
        self.config = config
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        memo_config = config['retrieval_settings']['methods'][1]['model_settings']
        try:
            # Load the local quantized generative model.
            local_model_path = "/root/.cache/modelscope/hub/MaxLeton13/chatglm3-6B-32k-int4"
            logging.info(f"Loading local model: {local_model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                device_map="auto",
                trust_remote_code=True
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                local_model_path,
                trust_remote_code=True
            )

            # Initialize the embedding model used for vector retrieval.
            logging.info(f"Loading embedding model: {memo_config['ret_model']}")
            self.embedding_model = SentenceTransformer(
                memo_config['ret_model'],
                device="cuda" if torch.cuda.is_available() else "cpu"
            )

            # Set up the cache directory.
            self.cache_dir = memo_config['cache_dir']
            os.makedirs(self.cache_dir, exist_ok=True)
        except Exception as e:
            logging.error(f"Failed to initialize MemoRAG: {e}")
            raise
    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        try:
            # Embed the query for the initial vector-based filtering step.
            query_embedding = self.embedding_model.encode(query)

            # Embed the candidate documents.
            docs_text = [doc['passage'] for doc in context]
            docs_embeddings = self.embedding_model.encode(docs_text)

            # Compute cosine similarity between the query and each document.
            similarities = torch.nn.functional.cosine_similarity(
                torch.tensor(query_embedding).unsqueeze(0),
                torch.tensor(docs_embeddings),
                dim=1
            )

            # Attach a similarity score to a copy of each document.
            scored_docs = []
            for doc, score in zip(context, similarities):
                doc_copy = doc.copy()
                doc_copy['memory_score'] = float(score)
                scored_docs.append(doc_copy)

            # Return documents sorted by score, highest first.
            return sorted(scored_docs, key=lambda x: x['memory_score'], reverse=True)
        except Exception as e:
            logging.error(f"MemoRAG retrieval failed: {e}")
            # Fall back to the original document list if retrieval fails.
            return context
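

# Usage sketch (illustrative only, not part of the module): builds a config with
# the nested layout that init_retriever reads ('retrieval_settings' -> 'methods'[1]
# -> 'model_settings' with 'ret_model' and 'cache_dir') and ranks a few documents.
# The embedding model name, cache path, and documents are placeholder assumptions.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = {
        'retrieval_settings': {
            'methods': [
                {},  # slot 0 unused here; MemoRAG reads index 1
                {
                    'model_settings': {
                        'ret_model': 'sentence-transformers/all-MiniLM-L6-v2',  # placeholder
                        'cache_dir': './memorag_cache',  # placeholder
                    }
                },
            ]
        }
    }

    retriever = MemoRAG(config)
    docs = [
        {'passage': 'MemoRAG combines a long-context model with vector retrieval.'},
        {'passage': 'Cosine similarity ranks passages against the query.'},
    ]
    for doc in retriever.retrieve("How are passages ranked?", docs):
        print(doc['memory_score'], doc['passage'])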