"""MemoRAG retriever: reranks candidate passages by embedding similarity,
backed by a local quantized causal LM for memory-style generation."""

import logging
import os
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

from .base import BaseRetriever

logger = logging.getLogger(__name__)


class MemoRAG(BaseRetriever):
    """Retriever that pairs a local causal LM with a SentenceTransformer
    embedding model.

    Candidate passages are scored by cosine similarity between the query
    embedding and each passage embedding, then returned in descending
    score order under the key ``memory_score``.
    """

    def __init__(self, config: Dict):
        self.config = config
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        """Load the local LM, its tokenizer, and the embedding model.

        Args:
            config: full settings dict; the MemoRAG model settings are read
                from ``config['retrieval_settings']['methods'][1]``.

        Raises:
            Exception: re-raised after logging if any model fails to load.
        """
        # NOTE(review): assumes the MemoRAG entry is always methods[1] —
        # fragile if the config ordering changes; consider looking the
        # method up by name instead.
        memo_config = config['retrieval_settings']['methods'][1]['model_settings']
        try:
            # Local quantized model. NOTE(review): hard-coded absolute path —
            # should come from config so other machines can run this.
            local_model_path = "/root/.cache/modelscope/hub/MaxLeton13/chatglm3-6B-32k-int4"
            logger.info(f"加载本地模型: {local_model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                device_map="auto",
                trust_remote_code=True,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                local_model_path,
                trust_remote_code=True,
            )

            # Embedding model used for similarity-based retrieval.
            logger.info(f"加载向量检索模型: {memo_config['ret_model']}")
            self.embedding_model = SentenceTransformer(
                memo_config['ret_model'],
                device="cuda" if torch.cuda.is_available() else "cpu",
            )

            # Cache directory for memory artifacts; created eagerly so later
            # writes cannot fail on a missing directory.
            self.cache_dir = memo_config['cache_dir']
            os.makedirs(self.cache_dir, exist_ok=True)

        except Exception as e:
            logger.error(f"初始化MemoRAG失败: {str(e)}")
            raise

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        """Score and rerank *context* passages against *query*.

        Args:
            query: the user query string.
            context: list of dicts, each expected to carry a ``'passage'``
                key holding the passage text.

        Returns:
            Copies of the input dicts with an added float ``'memory_score'``,
            sorted by score in descending order. On any failure the original
            list is returned unchanged (deliberate best-effort behavior).
        """
        # Guard: encoding an empty batch is wasted work and may error,
        # which would silently trip the broad fallback below.
        if not context:
            return []
        try:
            # Embed the query and all candidate passages.
            query_embedding = self.embedding_model.encode(query)
            docs_text = [doc['passage'] for doc in context]
            docs_embeddings = self.embedding_model.encode(docs_text)

            # Cosine similarity of the query against every passage;
            # as_tensor avoids copying when encode() already returns tensors.
            similarities = torch.nn.functional.cosine_similarity(
                torch.as_tensor(query_embedding).unsqueeze(0),
                torch.as_tensor(docs_embeddings),
                dim=1,
            )

            # Attach scores to copies so caller-owned dicts are not mutated.
            scored_docs = []
            for doc, score in zip(context, similarities):
                doc_copy = doc.copy()
                doc_copy['memory_score'] = float(score)
                scored_docs.append(doc_copy)

            return sorted(scored_docs, key=lambda x: x['memory_score'], reverse=True)

        except Exception as e:
            logger.error(f"MemoRAG检索失败: {str(e)}")
            # Best-effort fallback: return the unscored original order.
            return context