File size: 2,892 Bytes
7cc8bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from .base import BaseRetriever
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import logging
import os

class MemoRAG(BaseRetriever):
    """Retriever that re-ranks candidate passages by embedding similarity.

    Loads a local quantized causal LM and tokenizer (held for
    memory-augmented generation) plus a SentenceTransformer embedding
    model, then scores each candidate passage against the query with
    cosine similarity in ``retrieve``.
    """

    # Fallback LM path used when the config does not provide 'model_path'.
    DEFAULT_MODEL_PATH = "/root/.cache/modelscope/hub/MaxLeton13/chatglm3-6B-32k-int4"

    def __init__(self, config: Dict):
        self.config = config
        self.init_retriever(config)

    def init_retriever(self, config: Dict) -> None:
        """Load the LM, tokenizer and embedding model described by *config*.

        Args:
            config: full settings dict; the MemoRAG model settings are read
                from ``retrieval_settings.methods[1].model_settings``.

        Raises:
            Exception: re-raised (after logging with traceback) if any
                model fails to load.
        """
        # NOTE(review): assumes the MemoRAG entry is always methods[1] —
        # confirm against the config schema / caller.
        memo_config = config['retrieval_settings']['methods'][1]['model_settings']

        try:
            # Allow the LM path to be overridden via config; fall back to
            # the original hard-coded local quantized model.
            local_model_path = memo_config.get('model_path', self.DEFAULT_MODEL_PATH)

            logging.info(f"加载本地模型: {local_model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                device_map="auto",
                trust_remote_code=True
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                local_model_path,
                trust_remote_code=True
            )

            # Embedding model used for similarity scoring in retrieve().
            logging.info(f"加载向量检索模型: {memo_config['ret_model']}")
            self.embedding_model = SentenceTransformer(
                memo_config['ret_model'],
                device="cuda" if torch.cuda.is_available() else "cpu"
            )

            # Ensure the cache directory exists before it is first used.
            self.cache_dir = memo_config['cache_dir']
            os.makedirs(self.cache_dir, exist_ok=True)

        except Exception as e:
            # logging.exception keeps the traceback for post-mortem debugging
            # (the original logging.error discarded it).
            logging.exception(f"初始化MemoRAG失败: {str(e)}")
            raise

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        """Score and sort *context* passages by similarity to *query*.

        Args:
            query: free-text query.
            context: candidate documents; each dict is expected to carry a
                'passage' key with the document text.

        Returns:
            Shallow copies of the input dicts, each with an added
            'memory_score' float, sorted by that score descending. On any
            failure the original *context* is returned unchanged
            (best-effort behavior preserved from the original code).
        """
        # An empty candidate list makes the similarity computation
        # ill-defined; short-circuit explicitly instead of relying on the
        # except-path fallback.
        if not context:
            return []

        try:
            # Coarse filtering via vector retrieval.
            query_embedding = self.embedding_model.encode(query)

            # assumes every doc carries a 'passage' key — TODO confirm schema
            docs_text = [doc['passage'] for doc in context]
            docs_embeddings = self.embedding_model.encode(docs_text)

            # Cosine similarity between the query (broadcast over rows)
            # and every document embedding.
            similarities = torch.nn.functional.cosine_similarity(
                torch.tensor(query_embedding).unsqueeze(0),
                torch.tensor(docs_embeddings),
                dim=1
            )

            # Attach a score to a copy of each doc so the input list's
            # dicts are never mutated.
            scored_docs = []
            for doc, score in zip(context, similarities):
                doc_copy = doc.copy()
                doc_copy['memory_score'] = float(score)
                scored_docs.append(doc_copy)

            return sorted(scored_docs, key=lambda x: x['memory_score'], reverse=True)

        except Exception as e:
            logging.exception(f"MemoRAG检索失败: {str(e)}")
            # Best-effort fallback: hand back the unscored input list.
            return context