|
from pathlib import Path

import json
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from langchain_community.document_loaders import DirectoryLoader, JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

import config as cfg
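
# This module bundles two local RAG backends plus a persona helper:
#   * LocalRAG     -- LangChain loaders + HuggingFaceEmbeddings, indexed with
#                     FAISS and persisted via save_local().
#   * LocalRAG_new -- a raw transformers encoder (mean-pooled last hidden
#                     state) with brute-force cosine-similarity retrieval,
#                     persisted with torch.save().
#   * CosPlayer    -- loads a character description file and its optional
#                     prologue for role-play prompting.

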
class LocalRAG:
    """Build a FAISS-backed knowledge base with LangChain loaders."""

    def __init__(self,
                 rag_top_k=3,
                 doc_dir="rag/kb/BIGOLIVE及公司介绍/",
                 vector_db_path="rag/vector_db/",
                 embed_model=cfg.DEFAULT_EMBEDDING_MODEL):
        self.rag_top_k = rag_top_k
        self.doc_dir = doc_dir  # a single directory or a list of directories
        self.vector_db_path = vector_db_path
        self.embed_model = embed_model
        self.build_vector_db()

    def build_vector_db(self):
        json_docs, md_docs = [], []

        if isinstance(self.doc_dir, list):
            general_docs = []
            headers_to_split_on = [
                ("#", "Header 1"),
                ("##", "Header 2"),
                ("###", "Header 3"),
            ]
            markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

            for doc_dir in self.doc_dir:
                # Load everything except JSON and Markdown files, which get
                # dedicated handling below. (The original character-class glob
                # "**/*.[!json!md]*" did not actually exclude them.)
                loader = DirectoryLoader(doc_dir, glob="**/*.*",
                                         exclude=["**/*.json", "**/*.md"])
                general_docs.extend(loader.load())

                # JSON transcripts: one document per {spk, text} record.
                for json_file in Path(doc_dir).rglob("*.json"):
                    loader = JSONLoader(
                        file_path=str(json_file),
                        jq_schema=".[] | {spk: .spk, text: .text}",
                        text_content=False)
                    data = loader.load()
                    # JSONLoader serializes with ensure_ascii=True, so non-ASCII
                    # text arrives as \uXXXX escapes; re-dump as readable UTF-8.
                    for doc in data:
                        doc.page_content = json.dumps(json.loads(doc.page_content),
                                                      ensure_ascii=False)
                    json_docs.extend(data)

                # Markdown files: split on headers so each chunk keeps its context.
                for md_file in Path(doc_dir).rglob("*.md"):
                    with open(md_file, 'r', encoding='utf-8') as f:
                        content = f.read()
                    md_docs.extend(markdown_splitter.split_text(content))

        else:
            loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
            general_docs = loader.load()

        if general_docs:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50
            )
            # Character-split only the generic documents; JSON records and
            # header-split Markdown sections are already chunk-sized.
            chunks = text_splitter.split_documents(general_docs) + json_docs + md_docs
        else:
            chunks = json_docs + md_docs

        embeddings = HuggingFaceEmbeddings(model_name=self.embed_model)
        self.vector_db = FAISS.from_documents(chunks, embeddings)
        self.vector_db.save_local(self.vector_db_path)

    def reload_knowledge_base(self, target_doc_dir):
        self.doc_dir = target_doc_dir
        self.build_vector_db()


class LocalRAG_new:
    """RAG over a raw transformers encoder with cosine-similarity retrieval."""

    def __init__(self,
                 rag_top_k=3,
                 doc_dir="rag/kb/BIGOLIVE及公司介绍/",
                 vector_db_path="rag/vector_db/",
                 embed_model_path="princeton-nlp/sup-simcse-bert-large-uncased",
                 device=torch.device('cuda:2')):  # assumes this GPU exists; pass another device otherwise
        self.rag_top_k = rag_top_k
        self.doc_dir = doc_dir
        # Accept either a single directory or a list of directories.
        doc_dirs = doc_dir if isinstance(doc_dir, list) else [doc_dir]
        self.kb_name = '_'.join(Path(d).name for d in doc_dirs)
        self.embed_model_name = Path(embed_model_path).name
        self.vector_db_path = vector_db_path
        self.device = device

        self.tokenizer = AutoTokenizer.from_pretrained(embed_model_path)
        self.embed_model = AutoModel.from_pretrained(embed_model_path).to(device)
        self.vector_db = None
        self._vector_db = None
        self.build_vector_db()

    class VectorDB:
        """Thin wrapper exposing a FAISS-like similarity_search interface."""

        def __init__(self, rag):
            self._data = rag._vector_db
            self.rag = rag

        def similarity_search(self, query, k):
            with torch.inference_mode():
                # Embed the query the same way the chunks were embedded:
                # mean-pool the encoder's last hidden state. Truncate to the
                # encoder's 512-token limit instead of crashing on long input.
                query_token = self.rag.tokenizer(query, padding=True, truncation=True,
                                                 max_length=512, return_tensors="pt").to(self.rag.device)
                query_embed = self.rag.embed_model(**query_token)['last_hidden_state'].mean(dim=1)
                # Brute-force cosine similarity against every stored chunk
                # embedding; (1, D) broadcasts against (N, D).
                sim_query = F.cosine_similarity(query_embed, self._data['embeds'], dim=1, eps=1e-8)
                # Honor the caller's k (the original ignored it in favor of rag_top_k).
                max_ids_query = torch.argsort(sim_query, descending=True)[:k].cpu().numpy()
            return list(zip(np.array(self._data['chunks'])[max_ids_query],
                            sim_query[max_ids_query].cpu().numpy()))

    def build_vector_db(self):
        if isinstance(self.doc_dir, list):
            docs = []
            headers_to_split_on = [
                ("#", "Header 1"),
                ("##", "Header 2"),
                ("###", "Header 3"),
            ]
            markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

            for doc_dir in self.doc_dir:
                # Load everything except Markdown (handled below) and JSON
                # (not indexed by this backend).
                loader = DirectoryLoader(doc_dir, glob="**/*.*",
                                         exclude=["**/*.json", "**/*.md"])
                docs.extend(loader.load())

                # Markdown files: split on headers so each chunk keeps its context.
                for md_file in Path(doc_dir).rglob("*.md"):
                    with open(md_file, 'r', encoding='utf-8') as f:
                        content = f.read()
                    docs.extend(markdown_splitter.split_text(content))

        else:
            loader = DirectoryLoader(self.doc_dir, glob="**/*.*")
            docs = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50
        )
        chunks = text_splitter.split_documents(docs)

        # Embed every chunk (mean-pooled last hidden state) and stack the
        # embeddings into one (num_chunks, hidden_size) tensor for retrieval.
        with torch.inference_mode():
            all_embeds = []
            for chunk in chunks:
                chunk_token = self.tokenizer(chunk.page_content, padding=True, truncation=True,
                                             max_length=512, return_tensors="pt").to(self.device)
                chunk_embed = self.embed_model(**chunk_token)['last_hidden_state'].mean(dim=1)
                all_embeds.append(chunk_embed)
            all_embeds = torch.cat(all_embeds, dim=0)

        self._vector_db = {'chunks': chunks, 'embeds': all_embeds}
        self.vector_db = self.VectorDB(self)

        # Persist only the chunks/embeddings dict; the original saved the
        # VectorDB wrapper, which pickles the whole model and tokenizer too.
        Path(self.vector_db_path).mkdir(parents=True, exist_ok=True)
        torch.save(self._vector_db,
                   str(Path(self.vector_db_path) / f'{self.kb_name}_{self.embed_model_name}.pt'))
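
    # A minimal loading counterpart (sketch): restores the chunks/embeddings
    # dict written by build_vector_db() above. The method name load_vector_db
    # is an assumption, not part of the original interface.
    def load_vector_db(self, db_file):
        # weights_only=False because the saved dict holds Document objects,
        # not just tensors.
        self._vector_db = torch.load(db_file, map_location=self.device, weights_only=False)
        self.vector_db = self.VectorDB(self)
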
    def reload_knowledge_base(self, target_doc_dir):
        self.doc_dir = target_doc_dir
        self.build_vector_db()


class CosPlayer:
    """Load a character description and its optional prologue for role-play."""

    def __init__(self, description_file):
        self.update(description_file)

    def update(self, description_file):
        self.description_file = description_file
        with open(description_file, 'r', encoding='utf-8') as f:
            self.core_setting = f.read()
        self.characters_dir = Path(description_file).parent
        # A character's prologue lives at the same relative path, under
        # /prologues/ instead of /characters/.
        self.prologue_file = self.description_file.replace('/characters/', '/prologues/')
        if not Path(self.prologue_file).exists():
            self.prologue_file = None

    def get_all_characters(self):
        return [str(p) for p in self.characters_dir.rglob('*.txt')]

    def get_core_setting(self):
        return self.core_setting

    def get_prologue(self):
        if self.prologue_file:
            with open(self.prologue_file, 'r', encoding='utf-8') as f:
                return f.read()
        return None


if __name__ == "__main__":
    rag = LocalRAG()
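
    # A minimal retrieval sketch against the freshly built FAISS index; the
    # query string is only an illustrative placeholder.
    results = rag.vector_db.similarity_search("What is BIGO LIVE?", k=rag.rag_top_k)
    for doc in results:
        print(doc.page_content[:80])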