# LLM_demo/knowledge_base.py
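"""Local RAG knowledge-base utilities.

LocalRAG builds a FAISS vector store over a local document directory using
HuggingFace embeddings. LocalRAG_new builds an in-memory index with a
transformers encoder and ranks chunks by cosine similarity. CosPlayer loads
role-play character settings and their optional prologues.
"""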
from langchain_community.document_loaders import DirectoryLoader, JSONLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownTextSplitter, MarkdownHeaderTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from pathlib import Path
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
import numpy as np
import config as cfg
class LocalRAG:
def __init__(self,
rag_top_k=3,
doc_dir="rag/kb/BIGOLIVE及公司介绍/", # 默认加载这个,若选择角色扮演模式,可根据角色选择
vector_db_path="rag/vector_db/",
embed_model=cfg.DEFAULT_EMBEDDING_MODEL
):
self.rag_top_k = rag_top_k
        self.doc_dir = doc_dir  # document directory of the local knowledge base
        self.vector_db_path = vector_db_path  # where the vector database is stored
self.embed_model = embed_model
self.build_vector_db()
def build_vector_db(self):
        # load documents (supports PDF, TXT, DOCX, etc.)
if isinstance(self.doc_dir, list):
general_docs = []
json_docs = []
md_docs = []
for doc_dir in self.doc_dir:
                # general files (txt, pdf, docx, ...); JSON and Markdown get dedicated handling below
                loader = DirectoryLoader(doc_dir, glob="**/*.*", exclude=["**/*.json", "**/*.md"])
tmp_docs = loader.load()
general_docs.extend(tmp_docs)
                # extra handling for JSON files (expects a top-level array of {spk, text} records)
for json_file in Path(doc_dir).rglob("*.json"):
loader = JSONLoader(
file_path=str(json_file),
jq_schema=".[] | {spk: .spk, text: .text}",
text_content=False)
data = loader.load()
                    for iidx in range(len(data)):
                        # JSONLoader serializes each record with ASCII escapes; decode the
                        # \uXXXX sequences back into readable text
                        data[iidx].page_content = bytes(data[iidx].page_content, "utf-8").decode("unicode_escape")
json_docs.extend(data)
                # extra handling for Markdown files: split on headers so each section becomes its own document
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
for md_file in Path(doc_dir).rglob("*.md"):
with open(md_file, 'r') as f:
content = f.read()
                    # split the markdown content by its headers
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
md_header_splits = markdown_splitter.split_text(content)
md_docs.extend(md_header_splits)
# loader = UnstructuredMarkdownLoader(md_file, mode="elements")
# data = loader.load()
# docs.extend(data)
        else:
            loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
            general_docs = loader.load()
            json_docs, md_docs = [], []
        # chunk the plain-text documents; JSON records and Markdown sections are already unit-sized
        if len(general_docs) > 0:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50
            )
            chunks = text_splitter.split_documents(general_docs) + json_docs + md_docs
        else:
            chunks = json_docs + md_docs
        # embed the chunks and build the FAISS vector store
embeddings = HuggingFaceEmbeddings(model_name=self.embed_model)
self.vector_db = FAISS.from_documents(chunks, embeddings)
self.vector_db.save_local(self.vector_db_path)
def reload_knowledge_base(self, target_doc_dir):
self.doc_dir = target_doc_dir
self.build_vector_db()
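    # Usage sketch (illustrative only; the paths and query below are assumptions, not repo config):
    #   rag = LocalRAG(doc_dir=["rag/kb/some_kb/"])
    #   hits = rag.vector_db.similarity_search("some question", k=rag.rag_top_k)
    #   context = "\n".join(doc.page_content for doc in hits)
    # The index saved by build_vector_db can also be restored with FAISS.load_local.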
# def reset(self):
# self.vector_db = None
class LocalRAG_new:
def __init__(self,
rag_top_k=3,
doc_dir="rag/kb/BIGOLIVE及公司介绍/", # 默认加载这个,若选择角色扮演模式,可根据角色选择
vector_db_path="rag/vector_db/",
embed_model_path="princeton-nlp/sup-simcse-bert-large-uncased",
device=torch.device('cuda:2')):
self.rag_top_k = rag_top_k
        self.doc_dir = doc_dir  # document directory of the local knowledge base
        doc_dirs = doc_dir if isinstance(doc_dir, list) else [doc_dir]
        self.kb_name = '_'.join(Path(d).name for d in doc_dirs)
self.embed_model_name = Path(embed_model_path).name
        self.vector_db_path = vector_db_path  # where the vector database is stored
self.embed_model = embed_model_path
self.device = device
        # load the tokenizer and the embedding model
self.tokenizer = AutoTokenizer.from_pretrained(self.embed_model)
self.embed_model = AutoModel.from_pretrained(self.embed_model).to(device)
self.vector_db = None
self._vector_db = None
self.build_vector_db()
class VectorDB:
def __init__(self, rag):
self._data = rag._vector_db
self.rag = rag
def similarity_search(self, query, k):
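            """Return the k most similar (Document, cosine-similarity) pairs for `query`, best first."""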
            # possible query preprocessing (none for now)
            # query = input_optimize(query)
            # embed the query and compare it with the stored chunk embeddings
with torch.inference_mode():
query_token = self.rag.tokenizer(query, padding=True, truncation=False, return_tensors="pt").to(self.rag.device)
query_embed = self.rag.embed_model(**query_token)['last_hidden_state'].mean(dim=1)
sim_query = F.cosine_similarity(query_embed.repeat(len(self._data['embeds']), 1), self._data['embeds'], dim=1, eps=1e-8)
                max_ids_query = torch.argsort(sim_query, descending=True)[:k].cpu().detach().numpy()  # indices of the k most similar chunks
return list(zip(np.array(self._data['chunks'])[max_ids_query], sim_query[max_ids_query]))
def build_vector_db(self):
        # load documents (supports PDF, TXT, DOCX, etc.)
if isinstance(self.doc_dir, list):
docs = []
for doc_dir in self.doc_dir:
                loader = DirectoryLoader(doc_dir, glob="**/*.*", exclude=["**/*.json", "**/*.md"])  # skip JSON/Markdown; Markdown is split separately below
tmp_docs = loader.load()
docs.extend(tmp_docs)
                # # extra handling for JSON files (currently disabled)
# for json_file in Path(doc_dir).rglob("*.json"):
# loader = JSONLoader(
# file_path=str(json_file),
# jq_schema='.messages[].content',
# text_content=False)
# data = loader.load()
                # extra handling for Markdown files: split on headers so each section becomes its own document
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
for md_file in Path(doc_dir).rglob("*.md"):
with open(md_file, 'r') as f:
content = f.read()
                    # split the markdown content by its headers
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
md_header_splits = markdown_splitter.split_text(content)
docs.extend(md_header_splits)
# loader = UnstructuredMarkdownLoader(md_file, mode="elements")
# data = loader.load()
# docs.extend(data)
else:
loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
docs = loader.load()
        # split text into chunks
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50
)
chunks = text_splitter.split_documents(docs)
with torch.inference_mode():
chunk_and_embed = []
for chunk in chunks:
                chunk_token = self.tokenizer(chunk.page_content, padding=True, truncation=False, return_tensors="pt").to(self.device)
                # mean-pool the last hidden state to get one embedding vector per chunk
                chunk_embed = self.embed_model(**chunk_token)['last_hidden_state'].mean(dim=1)
chunk_and_embed.append((chunk, chunk_embed))
all_chunks, all_embeds = list(zip(*chunk_and_embed))
all_chunks, all_embeds = list(all_chunks), list(all_embeds)
all_embeds = torch.cat(all_embeds, dim=0)
self._vector_db = {'chunks': all_chunks, 'embeds': all_embeds}
self.vector_db = self.VectorDB(self)
        # persist only the chunk/embedding data rather than the full wrapper (which references the model)
        torch.save(self._vector_db, str(Path(self.vector_db_path) / f'{self.kb_name}_{self.embed_model_name}.pt'))
def reload_knowledge_base(self, target_doc_dir):
self.doc_dir = target_doc_dir
self.build_vector_db()
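    # Usage sketch (illustrative; the query text and device are assumptions):
    #   rag = LocalRAG_new(doc_dir=["rag/kb/some_kb/"], device=torch.device("cpu"))
    #   hits = rag.vector_db.similarity_search("some question", k=rag.rag_top_k)
    # The saved .pt file can be restored with torch.load (newer PyTorch may require weights_only=False).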
# def reset(self):
# self.vector_db = None
class CosPlayer:
def __init__(self, description_file):
self.update(description_file)
def update(self, description_file):
self.description_file = description_file
with open(description_file, 'r') as f:
all_lines = f.readlines()
self.core_setting = ''.join(all_lines)
self.characters_dir = Path(description_file).parent
self.prologue_file = self.description_file.replace('/characters/', '/prologues/')
if not Path(self.prologue_file).exists():
self.prologue_file = None
def get_all_characters(self):
return [str(i) for i in list(self.characters_dir.rglob('*.txt'))]
def get_core_setting(self):
return self.core_setting
def get_prologue(self):
if self.prologue_file:
with open(self.prologue_file, 'r') as f:
all_lines = f.readlines()
return ''.join(all_lines)
else:
return None
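# Usage sketch for CosPlayer (illustrative; the character file path is an assumption):
#   cos = CosPlayer("rag/characters/alice.txt")
#   system_prompt = cos.get_core_setting()
#   opening = cos.get_prologue()  # None when no matching file exists under /prologues/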
if __name__ == "__main__":
rag = LocalRAG()
# # rag.build_vector_db()
# doc_dir = "rag/debug"
# loader = DirectoryLoader(doc_dir, glob="**/*.*")
# docs = loader.load()
    # # split text into chunks
# text_splitter = RecursiveCharacterTextSplitter(
# chunk_size=500,
# chunk_overlap=50
# )
# chunks = text_splitter.split_documents(docs)
# pass
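    # Example retrieval against the index built above (commented out; the query string is an assumption):
    # results = rag.vector_db.similarity_search("What is BIGO LIVE?", k=rag.rag_top_k)
    # for doc in results:
    #     print(doc.page_content)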