Spaces:
Running
Running
File size: 2,892 Bytes
7cc8bc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from .base import BaseRetriever
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import logging
import os
class MemoRAG(BaseRetriever):
    """Retriever that ranks candidate passages by embedding similarity.

    On construction it loads a local causal language model with its
    tokenizer, and a SentenceTransformer embedding model. ``retrieve`` scores
    each candidate passage against the query with cosine similarity and
    returns the candidates ordered from most to least similar. Within this
    class only the embedding model is used by ``retrieve``; the language
    model and tokenizer are loaded and stored on the instance.
    """

    # Default on-disk location of the quantized chat model. Used when the
    # model settings do not supply a 'local_model_path' override.
    DEFAULT_MODEL_PATH = "/root/.cache/modelscope/hub/MaxLeton13/chatglm3-6B-32k-int4"

    def __init__(self, config: Dict):
        """Store ``config`` and load all models via ``init_retriever``.

        Args:
            config: Full application configuration mapping.

        Raises:
            Exception: Whatever ``init_retriever`` raises while loading.
        """
        # NOTE(review): BaseRetriever.__init__ is not invoked here — confirm
        # the base class needs no initialisation of its own.
        self.config = config
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        """Load the language model, tokenizer and embedding model.

        Reads ``config['retrieval_settings']['methods'][1]['model_settings']``
        and uses its ``ret_model`` and ``cache_dir`` entries, plus an optional
        ``local_model_path`` entry that overrides ``DEFAULT_MODEL_PATH``.

        Args:
            config: Full application configuration mapping.

        Raises:
            Exception: Any error raised while loading models or creating the
                cache directory is logged with its traceback and re-raised.
        """
        # The settings are taken from the second entry of the methods list;
        # presumably that entry is the MemoRAG one — verify against the
        # configuration file.
        memo_config = config['retrieval_settings']['methods'][1]['model_settings']
        try:
            # Use the local quantized model; the path may be overridden in
            # the model settings and otherwise falls back to the default.
            local_model_path = memo_config.get(
                'local_model_path', self.DEFAULT_MODEL_PATH
            )
            logging.info(f"加载本地模型: {local_model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                device_map="auto",
                trust_remote_code=True,
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                local_model_path,
                trust_remote_code=True,
            )

            # Initialise the embedding model used for vector retrieval.
            logging.info(f"加载向量检索模型: {memo_config['ret_model']}")
            self.embedding_model = SentenceTransformer(
                memo_config['ret_model'],
                device="cuda" if torch.cuda.is_available() else "cpu",
            )

            # Make sure the cache directory exists.
            self.cache_dir = memo_config['cache_dir']
            os.makedirs(self.cache_dir, exist_ok=True)
        except Exception as e:
            # logging.exception keeps the traceback alongside the message.
            logging.exception(f"初始化MemoRAG失败: {str(e)}")
            raise

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        """Rank ``context`` documents by similarity to ``query``.

        Args:
            query: The search query text.
            context: Candidate documents; each must have a ``'passage'`` key
                holding the text to embed.

        Returns:
            Shallow copies of the documents, each with an added
            ``'memory_score'`` float, sorted by that score in descending
            order. If ``context`` is empty it is returned unchanged. If
            scoring fails for any reason the error is logged and the original
            ``context`` is returned unchanged (best-effort fallback, without
            ``'memory_score'``).
        """
        # Nothing to rank: skip the embedding work. Without this guard the
        # similarity call fails on an empty batch and logs a spurious error.
        if not context:
            return context

        try:
            # Embed the query and every candidate passage.
            query_embedding = self.embedding_model.encode(query)
            docs_text = [doc['passage'] for doc in context]
            docs_embeddings = self.embedding_model.encode(docs_text)

            # Cosine similarity of the query against each passage; the query
            # gains a leading axis so it broadcasts across the passage rows.
            similarities = torch.nn.functional.cosine_similarity(
                torch.tensor(query_embedding).unsqueeze(0),
                torch.tensor(docs_embeddings),
                dim=1,
            )

            # Attach the score to a copy so the caller's dicts stay untouched.
            scored_docs = [
                {**doc, 'memory_score': score}
                for doc, score in zip(context, similarities.tolist())
            ]

            # Highest score first; sorted() is stable, so ties keep input order.
            return sorted(
                scored_docs, key=lambda item: item['memory_score'], reverse=True
            )
        except Exception as e:
            logging.exception(f"MemoRAG检索失败: {str(e)}")
            # Retrieval failed: fall back to the original document list.
            return context