Upload 38 files
- .gitattributes +2 -0
- Chroma_db/readme.txt +0 -0
- Config/__pycache__/config.cpython-310.pyc +0 -0
- Config/config.py +14 -0
- Faiss_db/readme.txt +0 -0
- Faiss_db/sss1/index.faiss +0 -0
- Faiss_db/sss1/index.pkl +3 -0
- Neo4j/__pycache__/graph_extract.cpython-310.pyc +0 -0
- Neo4j/__pycache__/neo4j_op.cpython-310.pyc +0 -0
- Neo4j/graph_extract.py +69 -0
- Neo4j/neo4j_op.py +105 -0
- Ollama_api/__pycache__/ollama_api.cpython-310.pyc +0 -0
- Ollama_api/ollama_api.py +21 -0
- embeding/__pycache__/asr_utils.cpython-310.pyc +0 -0
- embeding/__pycache__/chromadb.cpython-310.pyc +0 -0
- embeding/__pycache__/elasticsearchStore.cpython-310.pyc +0 -0
- embeding/__pycache__/faissdb.cpython-310.pyc +0 -0
- embeding/asr_utils.py +17 -0
- embeding/chromadb.py +134 -0
- embeding/elasticsearchStore.py +147 -0
- embeding/faissdb.py +138 -0
- embeding/tmp.txt +2 -0
- img/graph-tool.png +3 -0
- img/readme.txt +1 -0
- img/zhu.png +3 -0
- img/zhuye.png +0 -0
- img/复杂方式.png +0 -0
- img/微信图片_20240524180648.jpg +0 -0
- rag/__init__.py +0 -0
- rag/__pycache__/__init__.cpython-310.pyc +0 -0
- rag/__pycache__/config.cpython-310.pyc +0 -0
- rag/__pycache__/rag_class.cpython-310.pyc +0 -0
- rag/__pycache__/rerank.cpython-310.pyc +0 -0
- rag/__pycache__/rerank.cpython-39.pyc +0 -0
- rag/__pycache__/rerank_code.cpython-310.pyc +0 -0
- rag/rag_class.py +169 -0
- rag/rerank_code.py +21 -0
- test/__init__.py +0 -0
- test/graph2neo4j.py +25 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+img/graph-tool.png filter=lfs diff=lfs merge=lfs -text
+img/zhu.png filter=lfs diff=lfs merge=lfs -text
Chroma_db/readme.txt
ADDED
File without changes
Config/__pycache__/config.cpython-310.pyc
ADDED
Binary file (362 Bytes)
Config/config.py
ADDED
@@ -0,0 +1,14 @@
# Vector database selection: chroma = 1, faiss = 2, ElasticsearchStore = 3
VECTOR_DB = 2
DB_directory = "./Chroma_db/"
if VECTOR_DB == 2:
    DB_directory = "./Faiss_db/"
elif VECTOR_DB == 3:
    DB_directory = "es"

# Neo4j configuration
neo4j_host = "bolt://localhost:7687"
neo4j_name = "neo4j"
neo4j_pwd = "12345678"
# Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b and arcee-ai/arcee-agent:latest; so far qwen2:7b works best
neo4j_model = "qwen2:7b"
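For reference, a minimal usage sketch of this config; it assumes only the values defined above and that the project root is on the Python path:

# Usage sketch: other packages in this commit import the config at module level.
from Config.config import VECTOR_DB, DB_directory, neo4j_model

print(VECTOR_DB)      # 2 -> FAISS backend
print(DB_directory)   # "./Faiss_db/"
print(neo4j_model)    # "qwen2:7b"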
Faiss_db/readme.txt
ADDED
File without changes
Faiss_db/sss1/index.faiss
ADDED
Binary file (82 kB)
Faiss_db/sss1/index.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bb588f4bd46218f42b045c42163bdcf3cc76a19e37458823ceaeaf8a1454e3b
size 9362
Neo4j/__pycache__/graph_extract.cpython-310.pyc
ADDED
Binary file (2.51 kB)
Neo4j/__pycache__/neo4j_op.cpython-310.pyc
ADDED
Binary file (3.89 kB)
Neo4j/graph_extract.py
ADDED
@@ -0,0 +1,69 @@
from langchain_community.llms import Ollama
from Config.config import neo4j_model

# Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b and arcee-ai/arcee-agent:latest; so far qwen2:7b works best
llm = Ollama(model=neo4j_model)

json_example = {'edges': [
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'},
    {
        'label': 'label 1',
        'source': 'source 1',
        'target': 'target 1'}
],
    'nodes': [{'name': 'label 1'},
              {'name': 'label 2'},
              {'name': 'label 3'}]
}

__retriever_prompt = f"""
You are an AI expert specialising in knowledge-graph creation, and your goal is to capture relationships from the given input or request.
The user input may come in many forms, such as paragraphs, e-mails or text files.
Your task is to build a knowledge graph from that input.
Each element of nodes has only a name parameter; its value is an entity taken from a word or phrase in the input.
Each edge must also have a label parameter, where label is a word or phrase taken directly from the input, and the source and target of every edge come from the name values in nodes.

Respond with JSON only, in a form that can be jsonified in Python and fed directly to cy.add(data).
You may refer to the given example: {json_example}. Do not put a comma after the last element of the node and edge arrays,
and make sure the target and source of every edge match an existing node.
Do not wrap the JSON in markdown triple backticks; output the bare braces only.
"""


def generate_graph_info(raw_text: str) -> str | None:
    """
    generate graph info from raw text
    :param raw_text:
    :return:
    """
    messages = [
        {"role": "system", "content": "You now act as an information-extraction assistant: extract the information correctly from the user input and the AI answer, and remember not to translate the entities."},
        {"role": "user", "content": raw_text},
        {"role": "user", "content": __retriever_prompt}
    ]
    print("Parsing ....")
    for i in range(3):
        graph_info_result = llm.invoke(messages)
        if len(graph_info_result) < 10:
            print("-------", i, "-------------------")
            continue
        else:
            break
    print(graph_info_result)
    return graph_info_result


def update_graph(raw_text):
    # raw_text = request.json.get('text', '')
    try:
        result = generate_graph_info(raw_text)
        if '```' in result:
            graph_data = eval(result.split('```', 2)[1].replace("json", ''))
        else:
            graph_data = eval(str(result))
        return graph_data
    except Exception as e:
        return {'error': f"Error parsing graph data: {str(e)}"}
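A hedged usage sketch of update_graph; it assumes a local Ollama server is running with the model named in Config/config.py, and the sample sentence is taken from test/graph2neo4j.py below. The exact nodes and edges depend on the LLM output:

from Neo4j.graph_extract import update_graph

graph_data = update_graph("范冰冰1981年生于山东青岛,1998年参演电视剧《还珠格格》。")
if "error" not in graph_data:
    print(graph_data["nodes"])   # e.g. [{'name': '范冰冰'}, {'name': '山东青岛'}, ...]
    print(graph_data["edges"])   # e.g. [{'label': '生于', 'source': '范冰冰', 'target': '山东青岛'}, ...]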
Neo4j/neo4j_op.py
ADDED
@@ -0,0 +1,105 @@
from py2neo import Graph, Node, Relationship
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter


class KnowledgeGraph:
    def __init__(self, uri, user, password):
        self.graph = Graph(uri, auth=(user, password))

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        return data

    # Split the loaded documents into chunks
    def split_files(self, files, chunk_size=500, chunk_overlap=100):
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = text_splitter.split_documents(tmps)

        return splits

    def create_node(self, label, properties):
        matcher = self.graph.nodes.match(label, **properties)
        if matcher.first():
            return matcher.first()
        else:
            node = Node(label, **properties)
            self.graph.create(node)
            return node

    def create_relationship(self, label1, properties1, label2, properties2, relationship_type,
                            relationship_properties={}):
        node1 = self.create_node(label1, properties1)
        node2 = self.create_node(label2, properties2)

        matcher = self.graph.match((node1, node2), r_type=relationship_type)
        for rel in matcher:
            if all(rel[key] == value for key, value in relationship_properties.items()):
                return rel

        relationship = Relationship(node1, relationship_type, node2, **relationship_properties)
        self.graph.create(relationship)
        return relationship

    def delete_node(self, label, properties):
        matcher = self.graph.nodes.match(label, **properties)
        node = matcher.first()
        if node:
            self.graph.delete(node)
            return True
        return False

    def update_node(self, label, identifier, updates):
        matcher = self.graph.nodes.match(label, **identifier)
        node = matcher.first()
        if node:
            for key, value in updates.items():
                node[key] = value
            self.graph.push(node)
            return node
        return None

    def find_node(self, label, properties):
        matcher = self.graph.nodes.match(label, **properties)
        return list(matcher)

    def create_nodes(self, label, properties_list):
        nodes = []
        for properties in properties_list:
            node = self.create_node(label, properties)
            nodes.append(node)
        return nodes

    def create_relationships(self, relationships):
        created_relationships = []
        for rel in relationships:
            label1, properties1, label2, properties2, relationship_type = rel
            relationship = self.create_relationship(label1, properties1, label2, properties2, relationship_type)
            created_relationships.append(relationship)
        return created_relationships
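A minimal usage sketch of KnowledgeGraph, assuming a Neo4j instance is reachable with the credentials from Config/config.py; it mirrors test/graph2neo4j.py further down:

from Neo4j.neo4j_op import KnowledgeGraph
from Config.config import neo4j_host, neo4j_name, neo4j_pwd

kg = KnowledgeGraph(neo4j_host, neo4j_name, neo4j_pwd)
node = kg.create_node("node", {"name": "范冰冰"})          # created once, reused on re-runs
kg.create_relationship("node", {"name": "范冰冰"},
                       "node", {"name": "山东青岛"}, "生于")
print(kg.find_node("node", {"name": "范冰冰"}))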
Ollama_api/__pycache__/ollama_api.cpython-310.pyc
ADDED
Binary file (721 Bytes)
Ollama_api/ollama_api.py
ADDED
@@ -0,0 +1,21 @@
import requests
import json

# API helpers that list the models available from the local Ollama server
def get_llm():
    response = requests.get(url="http://localhost:11434/api/tags")
    result = json.loads(response.content)
    llms = []
    for llm in result["models"]:
        if "code" not in llm["name"] and "embed" not in llm["name"]:
            llms.append(llm["name"])
    return llms

def get_embeding_model():
    response = requests.get(url="http://localhost:11434/api/tags")
    result = json.loads(response.content)
    llms = []
    for llm in result["models"]:
        if "embed" in llm["name"]:
            llms.append(llm["name"])
    return llms
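Assuming an Ollama server on the default port 11434, these helpers can be used like this (the model names shown are examples only):

from Ollama_api.ollama_api import get_llm, get_embeding_model

print(get_llm())             # e.g. ['qwen2:7b', 'llama3:8b', ...]
print(get_embeding_model())  # e.g. ['mofanke/acge_text_embedding:latest']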
embeding/__pycache__/asr_utils.cpython-310.pyc
ADDED
Binary file (634 Bytes)
embeding/__pycache__/chromadb.cpython-310.pyc
ADDED
Binary file (3.91 kB)
embeding/__pycache__/elasticsearchStore.cpython-310.pyc
ADDED
Binary file (4.18 kB)
embeding/__pycache__/faissdb.cpython-310.pyc
ADDED
Binary file (4.21 kB)
embeding/asr_utils.py
ADDED
@@ -0,0 +1,17 @@
# coding:utf-8
from funasr import AutoModel
# paraformer-zh is a multi-functional asr model
# use vad, punc, spk or not as you need
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
                  # spk_model="cam++"
                  )

def get_spk_txt(file):
    res = model.generate(input=file,
                         batch_size_s=300,
                         hotword='魔搭')
    print(res[0]["text"])
    fw = "embeding/tmp.txt"
    f = open(fw, "w", encoding="utf-8")
    f.write('"context"\n' + res[0]["text"])
    f.close()
    return fw
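A hedged sketch of calling the ASR helper; the audio path is hypothetical, and funasr downloads the paraformer-zh models on first use:

from embeding.asr_utils import get_spk_txt

txt_path = get_spk_txt("data/demo.wav")   # hypothetical audio file
print(open(txt_path, encoding="utf-8").read())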
embeding/chromadb.py
ADDED
@@ -0,0 +1,134 @@
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt

class ChromaDB():
    def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Chroma_db/"):

        self.embedding = OllamaEmbeddings(model=embedding)
        self.persist_directory = persist_directory
        self.chromadb = Chroma(persist_directory=persist_directory)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text first
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    # Create a new collection and initialise it
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = self.chromadb.from_documents(documents=splits, collection_name=c_name,
                                                   embedding=self.embedding, persist_directory=self.persist_directory)
        print("Total number of chunks:", vectorstore._collection.count())

        return vectorstore

    # Add data to an existing collection
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start adding files...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = Chroma(persist_directory=self.persist_directory, collection_name=c_name,
                             embedding_function=self.embedding)
        vectorstore.add_documents(splits)
        print("Total number of chunks:", vectorstore._collection.count())

        return vectorstore

    # Delete specific files from a collection
    def del_files(self, del_files_name, c_name):

        vectorstore = self.chromadb._client.get_collection(c_name)
        del_ids = []
        vec_dict = vectorstore.get()
        for id, md in zip(vec_dict["ids"], vec_dict["metadatas"]):
            for dl in del_files_name:
                if dl in md["source"]:
                    del_ids.append(id)
        vectorstore.delete(ids=del_ids)
        print("Total number of chunks:", vectorstore.count())

        return vectorstore

    # Delete a whole knowledge-base collection
    def delete_collection(self, c_name):

        self.chromadb._client.delete_collection(c_name)

    # List all existing collections
    def get_all_collections_name(self):
        cl_names = []

        test = self.chromadb._client.list_collections()
        for i in range(len(test)):
            cl_names.append(test[i].name)
        return cl_names

    # List all files contained in a collection
    def get_collcetion_content_files(self, c_name):
        vectorstore = self.chromadb._client.get_collection(c_name)
        c_files = []
        vec_dict = vectorstore.get()
        for md in vec_dict["metadatas"]:
            c_files.append(md["source"])
        return list(set(c_files))


# if __name__ == "__main__":
#     chromadb = ChromaDB()
#     c_name = "sss3"
#
#     print(chromadb.get_all_collections_name())
#     chromadb.create_collection(["data/肾内科学.txt", "data/jl.pdf"], c_name=c_name)
#     print(chromadb.get_all_collections_name())
#     chromadb.add_chroma(["data/儿科学.txt"], c_name=c_name)
#     print(c_name, "contains:", chromadb.get_collcetion_content_files(c_name))
#     chromadb.del_files(["data/肾内科学.txt"], c_name=c_name)
#     print(c_name, "contains:", chromadb.get_collcetion_content_files(c_name))
#     print(chromadb.get_all_collections_name())
#     chromadb.delete_collection(c_name=c_name)
#     print(chromadb.get_all_collections_name())
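Beyond the commented example above, a hedged sketch of querying a collection built with this class; it assumes the default Ollama embedding model has been pulled, and the source file path is hypothetical:

from embeding.chromadb import ChromaDB

db = ChromaDB()
store = db.create_collection(["data/jl.pdf"], c_name="demo")   # hypothetical source file
print(store.similarity_search("What is this document about?", k=3))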
embeding/elasticsearchStore.py
ADDED
@@ -0,0 +1,147 @@
from elasticsearch import Elasticsearch
from langchain_elasticsearch.vectorstores import ElasticsearchStore
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from .asr_utils import get_spk_txt
import requests


class ElsStore():
    def __init__(self, embedding="mofanke/acge_text_embedding:latest", es_url="http://localhost:9200",
                 index_name='test_index'):
        self.embedding = OllamaEmbeddings(model=embedding)
        self.es_url = es_url
        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=index_name,
            embedding=self.embedding
        )

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text first
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    def get_count(self, c_name):
        # Return the number of chunks stored in the given index

        # Initialise the Elasticsearch client
        es = Elasticsearch([{
            'host': self.es_url.split(":")[1][2:],
            'port': int(self.es_url.split(":")[2]),
            'scheme': 'http'  # protocol to use
        }])

        # Target index name
        index_name = c_name

        # Get the total document count
        response = es.count(index=index_name)

        # Return the total document count
        return response['count']

    # Create a new index and initialise it
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore.from_documents(
            documents=splits,
            embedding=self.embedding,
            es_url=self.es_url,
            index_name=c_name,
        )

        self.elastic_vector_search.client.indices.refresh(index=c_name)

        print("Total number of chunks:", self.get_count(c_name))

        return self.elastic_vector_search

    # Add data to an existing index
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start adding files...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        self.elastic_vector_search = ElasticsearchStore(
            es_url=self.es_url,
            index_name=c_name,
            embedding=self.embedding
        )
        self.elastic_vector_search.add_documents(splits)
        self.elastic_vector_search.client.indices.refresh(index=c_name)
        print("Total number of chunks:", self.get_count(c_name))

        return self.elastic_vector_search

    # Delete a whole knowledge-base index
    def delete_collection(self, c_name):
        url = self.es_url + "/" + c_name
        # Send the DELETE request
        response = requests.delete(url)

        # Check the response status code
        if response.status_code == 200:
            return f"Index '{c_name}' was deleted successfully."
        elif response.status_code == 404:
            return f"Index '{c_name}' does not exist."
        else:
            return f"Error while deleting the index: {response.status_code}, {response.text}"

    # List all existing index names
    def get_all_collections_name(self):
        indices = self.elastic_vector_search.client.indices.get_alias()
        index_names = list(indices.keys())

        return index_names

    def get_collcetion_content_files(self, c_name):
        return []

    # Delete specific files from an index (not implemented yet)
    def del_files(self, del_files_name, c_name):
        return None
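A hedged usage sketch, assuming Elasticsearch is reachable at the default http://localhost:9200 and the source file path is hypothetical:

from embeding.elasticsearchStore import ElsStore

es_store = ElsStore()
es_store.create_collection(["data/jl.pdf"], "demo")   # hypothetical source file
print(es_store.get_count("demo"))
print(es_store.get_all_collections_name())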
embeding/faissdb.py
ADDED
@@ -0,0 +1,138 @@
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import shutil
import os
from .asr_utils import get_spk_txt


class FaissDB():
    def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Faiss_db/"):

        self.embedding = OllamaEmbeddings(model=embedding)
        self.persist_directory = persist_directory
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50, add_start_index=True)

    def parse_data(self, file):
        if "txt" in file.lower() or "csv" in file.lower():
            try:
                loaders = UnstructuredCSVLoader(file)
                data = loaders.load()
            except:
                loaders = TextLoader(file, encoding="utf-8")
                data = loaders.load()
        if ".doc" in file.lower() or ".docx" in file.lower():
            loaders = UnstructuredWordDocumentLoader(file)
            data = loaders.load()
        if "pdf" in file.lower():
            loaders = UnstructuredPDFLoader(file)
            data = loaders.load()
        if ".xlsx" in file.lower():
            loaders = UnstructuredExcelLoader(file)
            data = loaders.load()
        if ".md" in file.lower():
            loaders = UnstructuredMarkdownLoader(file)
            data = loaders.load()
        if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
            # Transcribe the audio to text first
            fw = get_spk_txt(file)
            loaders = UnstructuredCSVLoader(fw)
            data = loaders.load()
            tmp = []
            for i in data:
                i.metadata["source"] = file
                tmp.append(i)
            data = tmp
        return data

    # Create a new collection and initialise it
    def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start building the database ....")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = FAISS.from_documents(documents=splits,
                                           embedding=self.embedding)
        vectorstore.save_local(self.persist_directory + c_name)
        print("Total number of chunks:", vectorstore.index.ntotal)

        return vectorstore

    # Add data to an existing collection
    def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        print("Start adding files...")
        tmps = []
        for file in files:
            data = self.parse_data(file)
            tmps.extend(data)

        splits = self.text_splitter.split_documents(tmps)

        vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
                                       allow_dangerous_deserialization=True)
        vectorstore.add_documents(documents=splits)
        vectorstore.save_local(self.persist_directory + c_name)
        print("Total number of chunks:", vectorstore.index.ntotal)

        return vectorstore

    # Delete specific files from a collection
    def del_files(self, del_files_name, c_name):

        vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
                                       allow_dangerous_deserialization=True)
        del_ids = []
        vec_dict = vectorstore.docstore._dict
        for id, md in vec_dict.items():
            for dl in del_files_name:
                if dl in md.metadata["source"]:
                    del_ids.append(id)
        vectorstore.delete(ids=del_ids)
        vectorstore.save_local(self.persist_directory + c_name)
        print("Total number of chunks:", vectorstore.index.ntotal)

        return vectorstore

    # Delete a whole knowledge-base collection
    def delete_collection(self, c_name):
        shutil.rmtree(self.persist_directory + c_name)

    # List all existing collections
    def get_all_collections_name(self):
        cl_names = [i for i in os.listdir(self.persist_directory) if os.path.isdir(self.persist_directory + i)]

        return cl_names

    # List all files contained in a collection
    def get_collcetion_content_files(self, c_name):
        vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
                                       allow_dangerous_deserialization=True)
        c_files = []
        vec_dict = vectorstore.docstore._dict
        for _, md in vec_dict.items():
            c_files.append(md.metadata["source"])

        return list(set(c_files))


# if __name__ == "__main__":
#     chromadb = FaissDB()
#     c_name = "sss3"
#
#     print(chromadb.get_all_collections_name())
#     chromadb.create_collection(["data/jl.txt", "data/jl.pdf"], c_name=c_name)
#     print(chromadb.get_all_collections_name())
#     chromadb.add_chroma(["data/tmp.txt"], c_name=c_name)
#     print(c_name, "contains:", chromadb.get_collcetion_content_files(c_name))
#     chromadb.del_files(["data/tmp.txt"], c_name=c_name)
#     print(c_name, "contains:", chromadb.get_collcetion_content_files(c_name))
#     print(chromadb.get_all_collections_name())
#     chromadb.delete_collection(c_name=c_name)
#     print(chromadb.get_all_collections_name())
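In the same spirit as the commented example above, a hedged sketch of building and then querying a FAISS collection; the file path is taken from that example and the query string is illustrative:

from embeding.faissdb import FaissDB

fdb = FaissDB()
store = fdb.create_collection(["data/jl.txt"], c_name="demo")
print(store.similarity_search("What is this document about?", k=3))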
embeding/tmp.txt
ADDED
@@ -0,0 +1,2 @@
"context"
你是不是觉得自己说话的声音直来直去呢?现在告诉你一个主持人吐字的小秘密,那就是每个字在口腔当中像是翻跟头一样打一圈再出来。比如说故人西辞黄鹤楼,而不是故人西辞黄鹤楼。再比如说乌衣巷口夕阳斜,而不是乌衣巷口夕阳斜,你也试试看抖音。
img/graph-tool.png
ADDED
Image (stored with Git LFS)
img/readme.txt
ADDED
@@ -0,0 +1 @@
1
img/zhu.png
ADDED
Image (stored with Git LFS)
img/zhuye.png
ADDED
img/复杂方式.png
ADDED
img/微信图片_20240524180648.jpg
ADDED
rag/__init__.py
ADDED
File without changes
rag/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (132 Bytes)
rag/__pycache__/config.cpython-310.pyc
ADDED
Binary file (364 Bytes)
rag/__pycache__/rag_class.cpython-310.pyc
ADDED
Binary file (5.39 kB)
rag/__pycache__/rerank.cpython-310.pyc
ADDED
Binary file (878 Bytes)
rag/__pycache__/rerank.cpython-39.pyc
ADDED
Binary file (869 Bytes)
rag/__pycache__/rerank_code.cpython-310.pyc
ADDED
Binary file (883 Bytes)
rag/rag_class.py
ADDED
@@ -0,0 +1,169 @@
from langchain_community.vectorstores import Chroma, FAISS
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain.prompts import ChatPromptTemplate
from .rerank_code import rerank_topn
from Config.config import VECTOR_DB, DB_directory
from langchain_elasticsearch.vectorstores import ElasticsearchStore


class RAG_class:
    def __init__(self, model="qwen2:7b", embed="milkey/dmeta-embedding-zh:f16", c_name="sss1",
                 persist_directory="E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/", es_url="http://localhost:9200"):
        template = """
        Answer the question below from the reference content only; do not improvise. Summarise the answer from the reference content, and if the answer cannot be found there, reply that the answer cannot be obtained from the reference content.

        Reference content: {context}

        Question: {question}
        """

        self.prompts = ChatPromptTemplate.from_template(template)

        # Use question expansion + recursive answering to reach the final answer
        template1 = """You are a helpful assistant that generates multiple sub-questions related to an input question.
        The goal is to break the input down into a set of sub-problems / sub-questions that can be answered in isolation.
        Generate multiple search queries related to: {question}
        Output 4 related questions, separated by newlines:"""
        self.prompt_questions = ChatPromptTemplate.from_template(template1)

        # Build the question/answer-pair prompt
        template2 = """
        Here is the question you need to answer:

        \n--\n {question} \n---\n

        Here are any available background question + answer pairs:

        \n--\n {q_a_pairs} \n---\n

        Here is additional context relevant to the question:

        \n--\n {context} \n---\n

        Use the context and the background question + answer pairs above to answer the question: {question}. The answer is:
        """
        self.decomposition_prompt = ChatPromptTemplate.from_template(template2)

        self.llm = Ollama(model=model)
        self.embeding = OllamaEmbeddings(model=embed)
        try:
            if VECTOR_DB == 1:
                self.vectstore = Chroma(embedding_function=self.embeding, collection_name=c_name,
                                        persist_directory=persist_directory)
            elif VECTOR_DB == 2:
                self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name, embeddings=self.embeding,
                                                  allow_dangerous_deserialization=True)
            elif VECTOR_DB == 3:
                self.vectstore = ElasticsearchStore(
                    es_url=es_url,
                    index_name=c_name,
                    embedding=self.embeding
                )
            self.retriever = self.vectstore.as_retriever()
        except Exception as e:
            print("Model-only mode, no vector store loaded:", e)

    # Post-processing
    def format_docs(self, docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Classic retrieval: recall with the single question, then let the LLM summarise the answer
    def simple_chain(self, question):
        _chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | self.prompts
            | self.llm
            | StrOutputParser()
        )
        answer = _chain.invoke(question)
        return answer

    def rerank_chain(self, question):
        retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
        docs = retriever.invoke(question)
        docs = rerank_topn(question, docs, N=5)
        _chain = (
            self.prompts
            | self.llm
            | StrOutputParser()
        )
        answer = _chain.invoke({"context": self.format_docs(docs), "question": question})
        return answer

    def format_qa_pairs(self, question, answer):
        formatted_string = ""
        formatted_string += f"Question: {question}\nAnswer:{answer}\n\n"
        return formatted_string

    # Generate the expanded sub-questions for a question
    def decomposition_chain(self, question):
        _chain = (
            {"question": RunnablePassthrough()}
            | self.prompt_questions
            | self.llm
            | StrOutputParser()
            | (lambda x: x.split("\n"))
        )

        questions = _chain.invoke({"question": question}) + [question]

        return questions

    # Recursive multi-question retrieval: after every round, the question and its answer are added as reference for the next round
    def rag_chain(self, questions):
        q_a_pairs = ""
        for q in questions:
            _chain = (
                {"context": itemgetter("question") | self.retriever,
                 "question": itemgetter("question"),
                 "q_a_pairs": itemgetter("q_a_pairs")
                 }
                | self.decomposition_prompt
                | self.llm
                | StrOutputParser()
            )

            answer = _chain.invoke({"question": q, "q_a_pairs": q_a_pairs})
            # Accumulate the new question/answer pair so later rounds can use it as background
            q_a_pair = self.format_qa_pairs(q, answer)
            q_a_pairs = q_a_pairs + "\n----\n" + q_a_pair
        return answer

    # Format the chat history into a single string
    def format_chat_history(self, history):
        formatted_history = ""
        for role, content in history:
            formatted_history += f"{role}: {content}\n"
        return formatted_history

    # Multi-turn chat backed by the Ollama model, without using the knowledge base
    def mult_chat(self, chat_history):
        # Format the chat history
        formatted_history = self.format_chat_history(chat_history)

        # Call the model to generate a reply
        response = self.llm.invoke(formatted_history)
        return response


# if __name__ == "__main__":
#     rag = RAG_class(model="deepseek-r1:14b")
#     question = "人卫社官网网址是?"
#     questions = rag.decomposition_chain(question)
#     print(questions)
#     answer = rag.rag_chain(questions)
#     print(answer)
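A hedged sketch of the two single-question chains; it assumes the FAISS backend selected in Config/config.py, the shipped Faiss_db/sss1 index built with the default embedding model, and, for rerank_chain, the local bge-reranker-large path referenced in rag/rerank_code.py:

from rag.rag_class import RAG_class
from Config.config import DB_directory

rag = RAG_class(model="qwen2:7b", c_name="sss1", persist_directory=DB_directory)
print(rag.simple_chain("人卫社官网网址是?"))   # single retrieval pass, then summarise
print(rag.rerank_chain("人卫社官网网址是?"))   # retrieve 10 chunks, rerank to top 5, then summarise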
rag/rerank_code.py
ADDED
@@ -0,0 +1,21 @@
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Local bge-reranker-large cross-encoder used to rerank retrieved chunks
tokenizer = AutoTokenizer.from_pretrained('E:\\model\\bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('E:\\model\\bge-reranker-large')
model.eval()


def rerank_topn(question, docs, N=5):
    # Score every (question, passage) pair and keep the N highest-scoring documents
    pairs = []
    for i in docs:
        pairs.append([question, i.page_content])

    with torch.no_grad():
        inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
        scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
        scores = scores.argsort().numpy()[::-1][:N]  # indices of the top-N documents
    bk = []
    for i in scores:
        bk.append(docs[i])
    return bk
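A hedged sketch of rerank_topn on hand-made documents; the passages are placeholders, and any langchain Document objects (for example those returned by a retriever) work the same way:

from langchain_core.documents import Document
from rag.rerank_code import rerank_topn

docs = [Document(page_content=t) for t in ["passage A", "passage B", "passage C"]]
top2 = rerank_topn("which passage answers my question?", docs, N=2)
print([d.page_content for d in top2])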
test/__init__.py
ADDED
File without changes
test/graph2neo4j.py
ADDED
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
import sys
sys.path.append(r"..//")
from Neo4j.neo4j_op import KnowledgeGraph
from Neo4j.graph_extract import update_graph
from Config.config import neo4j_host, neo4j_name, neo4j_pwd


kg = KnowledgeGraph(neo4j_host, neo4j_name, neo4j_pwd)


if __name__ == "__main__":

    text = """范冰冰,1981年9月16日生于山东青岛,毕业于上海师范大学谢晋影视艺术学院,中国女演员,歌手。
1998年参演电视剧《还珠格格》成名。2004年主演电影《手机》获得第27届大众电影百花奖最佳女演员奖。"""
    res = update_graph(text)
    # Create the nodes in bulk
    nodes = kg.create_nodes("node", res["nodes"])
    print(nodes)
    # Create the relationships in bulk
    relationships = kg.create_relationships([
        ("node", {"name": edge["source"]}, "node", {"name": edge["target"]}, edge["label"]) for edge in res["edges"]
    ])
    print(relationships)