# zhuhai111's picture
# Upload 43 files
# 7cc8bc0 verified
from .base import BaseRetriever
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import logging
import os
class MemoRAG(BaseRetriever):
    """Memory-augmented retriever that ranks candidate passages against a
    query by sentence-embedding cosine similarity.

    NOTE(review): the causal LM and tokenizer loaded in ``init_retriever``
    are never used by ``retrieve`` in this file — presumably reserved for
    memory/summary generation elsewhere; confirm before removing.
    """

    # Fallback path for the quantized LM, used when the config section
    # does not provide an explicit 'model_path'.
    DEFAULT_MODEL_PATH = "/root/.cache/modelscope/hub/MaxLeton13/chatglm3-6B-32k-int4"

    def __init__(self, config: Dict):
        self.config = config
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        """Load the local causal LM, its tokenizer, and the embedding model.

        Args:
            config: full pipeline config; the MemoRAG settings are read from
                ``retrieval_settings.methods[1].model_settings``.

        Raises:
            Exception: re-raised (after logging) if any model fails to load.
        """
        # assumes methods[1] is the MemoRAG entry — TODO confirm against the
        # config schema used by callers.
        memo_config = config['retrieval_settings']['methods'][1]['model_settings']
        try:
            # Allow the config to override the quantized-LM location; fall
            # back to the original hard-coded path for backward compatibility.
            local_model_path = memo_config.get('model_path', self.DEFAULT_MODEL_PATH)
            logging.info("加载本地模型: %s", local_model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                device_map="auto",
                trust_remote_code=True
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                local_model_path,
                trust_remote_code=True
            )
            # Embedding model that performs the actual similarity ranking.
            logging.info("加载向量检索模型: %s", memo_config['ret_model'])
            self.embedding_model = SentenceTransformer(
                memo_config['ret_model'],
                device="cuda" if torch.cuda.is_available() else "cpu"
            )
            # Cache directory for retrieval artifacts (created if missing).
            self.cache_dir = memo_config['cache_dir']
            os.makedirs(self.cache_dir, exist_ok=True)
        except Exception as e:
            logging.error("初始化MemoRAG失败: %s", e)
            raise

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        """Score every passage in *context* against *query* and return copies
        sorted by descending ``'memory_score'``.

        Each returned dict is a shallow copy of the input dict with a
        ``'memory_score'`` float added. On any failure the error is logged
        and the original *context* list is returned unchanged (best-effort
        fallback, preserving the original contract).
        """
        if not context:
            # Nothing to score; also avoids encoding an empty batch, which
            # would previously raise and fall through the except branch.
            return []
        try:
            # Embed the query and all candidate passages.
            query_embedding = self.embedding_model.encode(query)
            docs_text = [doc['passage'] for doc in context]
            docs_embeddings = self.embedding_model.encode(docs_text)
            # Cosine similarity of the query vector against every doc vector.
            # as_tensor avoids copying the numpy arrays returned by encode().
            similarities = torch.nn.functional.cosine_similarity(
                torch.as_tensor(query_embedding).unsqueeze(0),
                torch.as_tensor(docs_embeddings),
                dim=1
            )
            # Attach a score to a copy of each doc (inputs stay unmutated).
            scored_docs = []
            for doc, score in zip(context, similarities):
                doc_copy = doc.copy()
                doc_copy['memory_score'] = float(score)
                scored_docs.append(doc_copy)
            # Highest-scoring passages first.
            return sorted(scored_docs, key=lambda x: x['memory_score'], reverse=True)
        except Exception as e:
            logging.error("MemoRAG检索失败: %s", e)
            # Best-effort fallback: return the unscored original list.
            return context