|
import shutil |
|
from pathlib import Path |
|
|
|
import torch |
|
from loguru import logger |
|
|
|
from ....web_configs import WEB_CONFIGS |
|
from ...database.product_db import get_db_product_info |
|
from .feature_store import gen_vector_db |
|
from .retriever import CacheRetriever |
|
|
|
|
|
# Upper bound, in characters, used when computing how much retrieved context
# may be requested from the retriever (see build_rag_prompt).
CONTEXT_MAX_LENGTH = 3000

# Prompt template with two positional slots: the retrieved manual text first,
# then the customer's question.
GENERATE_TEMPLATE = (
    "这是说明书:“{}”\n"
    " 客户的问题:“{}” \n"
    " 请阅读说明并运用你的性格进行解答。"
)

# Module-wide retriever cache; stays None until load_rag_model() assigns it.
RAG_RETRIEVER = None
|
|
|
|
|
def build_rag_prompt(rag_retriever: CacheRetriever, product_name, prompt):
    """Build the prompt sent to the model, adding retrieved manual text if any.

    Args:
        rag_retriever: Retriever cache; its ``"default"`` entry is queried.
        product_name: Product name prefixed to the retrieval query.
        prompt: The customer's original question.

    Returns:
        ``GENERATE_TEMPLATE`` filled with the retrieved context and ``prompt``
        when the retrieved context is longer than one character; ``prompt``
        unchanged when it is not; an empty string when the cache lookup
        returns a tuple instead of a retriever object.
    """
    retriever = rag_retriever.get(fs_id="default")

    # A tuple result means no usable retriever came back (presumably an error
    # payload from CacheRetriever.get -- confirm against its implementation).
    if isinstance(retriever, tuple):
        logger.info(f" @@@ GOT real_retriever == tuple : {retriever}")
        return ""

    query_text = f"商品名:{product_name}。{prompt}"
    # Reserve twice the template length out of the overall context budget.
    context_budget = CONTEXT_MAX_LENGTH - 2 * len(GENERATE_TEMPLATE)
    _chunk, db_context, references = retriever.query(query_text, context_max_length=context_budget)
    logger.info(f"db_context = {db_context}")

    has_context = db_context is not None and len(db_context) > 1
    if not has_context:
        logger.info("db_context get error")
    # Fall back to the bare question when nothing useful was retrieved.
    prompt_rag = GENERATE_TEMPLATE.format(db_context, prompt) if has_context else prompt

    logger.info(f"RAG reference = {references}")
    logger.info("=" * 20)

    return prompt_rag
|
|
|
|
|
def init_rag_retriever(rag_config: str, db_path: str):
    """Create a ``CacheRetriever`` and request its ``"default"`` entry once.

    Args:
        rag_config: Path to the RAG configuration, passed as ``config_path``.
        db_path: Vector database directory, passed as ``work_dir``.

    Returns:
        The newly created ``CacheRetriever`` instance.
    """
    # Release unused cached GPU memory before the retriever is built.
    torch.cuda.empty_cache()

    cache = CacheRetriever(config_path=rag_config)

    # The return value is intentionally discarded; the call is made up front
    # (presumably so the "default" retriever is loaded and cached -- confirm
    # against CacheRetriever.get).
    cache.get(fs_id="default", config_path=rag_config, work_dir=db_path)

    return cache
|
|
|
|
|
async def gen_rag_db(user_id, force_gen=False):
    """Generate the RAG vector database from the user's product manuals.

    The instruction file of every product returned by
    ``get_db_product_info(user_id)`` is copied into a temporary staging
    directory, which is then handed to ``gen_vector_db``. The staging
    directory is removed afterwards, including when generation fails.

    Args:
        user_id: Forwarded to ``get_db_product_info`` to select the products.
        force_gen: When True, an existing database directory is deleted and
            the database is regenerated. When False and the database
            directory already exists, the function returns without doing
            anything.
    """
    vector_db_dir = Path(WEB_CONFIGS.RAG_VECTOR_DB_DIR)
    if vector_db_dir.exists():
        if not force_gen:
            # Database already present and no rebuild requested.
            return
        shutil.rmtree(vector_db_dir)

    # Start from an empty staging directory so stale manuals are not indexed.
    staging_dir = Path(WEB_CONFIGS.PRODUCT_INSTRUCTION_DIR_GEN_DB_TMP)
    if staging_dir.exists():
        shutil.rmtree(staging_dir)
    staging_dir.mkdir(exist_ok=True, parents=True)

    try:
        product_list, _ = await get_db_product_info(user_id)

        instructions_root = Path(
            WEB_CONFIGS.SERVER_FILE_ROOT,
            WEB_CONFIGS.PRODUCT_FILE_DIR,
            WEB_CONFIGS.INSTRUCTIONS_DIR,
        )
        for info in product_list:
            # Only the file name of the stored instruction path is used; the
            # file itself is read from the server's instructions directory.
            file_name = Path(info.instruction).name
            shutil.copyfile(instructions_root / file_name, staging_dir / file_name)

        logger.info("Generating rag database, pls wait ...")

        gen_vector_db(
            WEB_CONFIGS.RAG_CONFIG_PATH,
            str(staging_dir.absolute()),
            WEB_CONFIGS.RAG_VECTOR_DB_DIR,
        )
    finally:
        # Always drop the staging copy, even if copying or generation raised,
        # so a failed run does not leave temporary files behind.
        shutil.rmtree(staging_dir)
|
|
|
|
|
async def load_rag_model(user_id):
    """Ensure the vector database exists, then set the module-level retriever.

    Args:
        user_id: Forwarded to ``gen_rag_db`` when the database must be built.
    """
    global RAG_RETRIEVER

    # Called without force_gen, so an existing database directory is reused.
    await gen_rag_db(user_id)

    config_path = WEB_CONFIGS.RAG_CONFIG_PATH
    vector_db_dir = WEB_CONFIGS.RAG_VECTOR_DB_DIR
    RAG_RETRIEVER = init_rag_retriever(rag_config=config_path, db_path=vector_db_dir)
    logger.info("load rag model done !...")
|
|
|
|
|
async def rebuild_rag_db(user_id, db_name="default"):
    """Regenerate the vector database and reload it into the cached retriever.

    Args:
        user_id: Forwarded to ``gen_rag_db`` to select the products to index.
        db_name: Cache entry (``fs_id``) of the retriever to drop and reload.
    """
    # Force regeneration so the on-disk database reflects current products.
    await gen_rag_db(user_id, force_gen=True)

    # RAG_RETRIEVER is only assigned by load_rag_model(); without this guard
    # the calls below fail with AttributeError on None after the database has
    # already been rebuilt.
    if RAG_RETRIEVER is None:
        logger.warning("RAG retriever is not loaded, skip reloading the rebuilt vector database")
        return

    # Drop the cached entry, then request it again so it is rebuilt from the
    # freshly generated database directory.
    RAG_RETRIEVER.pop(db_name)
    RAG_RETRIEVER.get(fs_id=db_name, config_path=WEB_CONFIGS.RAG_CONFIG_PATH, work_dir=WEB_CONFIGS.RAG_VECTOR_DB_DIR)
|
|