chengyingmo committed
Commit c604980 · verified · 1 Parent(s): ce2b87c

Upload 38 files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ img/graph-tool.png filter=lfs diff=lfs merge=lfs -text
+ img/zhu.png filter=lfs diff=lfs merge=lfs -text
Chroma_db/readme.txt ADDED
File without changes
Config/__pycache__/config.cpython-310.pyc ADDED
Binary file (362 Bytes).
 
Config/config.py ADDED
@@ -0,0 +1,14 @@
+ # Vector database selection: chroma = 1, faiss = 2, ElasticsearchStore = 3
+ VECTOR_DB = 2
+ DB_directory = "./Chroma_db/"
+ if VECTOR_DB == 2:
+     DB_directory = "./Faiss_db/"
+ elif VECTOR_DB == 3:
+     DB_directory = "es"
+
+ # Neo4j configuration
+ neo4j_host = "bolt://localhost:7687"
+ neo4j_name = "neo4j"
+ neo4j_pwd = "12345678"
+ # Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b and arcee-ai/arcee-agent:latest; so far qwen2:7b works best
+ neo4j_model = "qwen2:7b"
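
As a usage sketch (illustrative only, not part of this commit), these settings are simply imported elsewhere in the repo, e.g. Neo4j/graph_extract.py pulls neo4j_model and test/graph2neo4j.py pulls the Neo4j connection values; the snippet below only assumes the defaults defined above and is run from the project root.

# Illustrative only: read the committed defaults and report which vector store is active.
from Config.config import VECTOR_DB, DB_directory, neo4j_host, neo4j_name, neo4j_model

db_names = {1: "chroma", 2: "faiss", 3: "ElasticsearchStore"}
print("Vector store:", db_names.get(VECTOR_DB), "->", DB_directory)
print("Neo4j:", neo4j_host, "user:", neo4j_name, "extraction LLM:", neo4j_model)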
Faiss_db/readme.txt ADDED
File without changes
Faiss_db/sss1/index.faiss ADDED
Binary file (82 kB).
 
Faiss_db/sss1/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2bb588f4bd46218f42b045c42163bdcf3cc76a19e37458823ceaeaf8a1454e3b
+ size 9362
Neo4j/__pycache__/graph_extract.cpython-310.pyc ADDED
Binary file (2.51 kB).
 
Neo4j/__pycache__/neo4j_op.cpython-310.pyc ADDED
Binary file (3.89 kB).
 
Neo4j/graph_extract.py ADDED
@@ -0,0 +1,69 @@
+ from langchain_community.llms import Ollama
+ from Config.config import neo4j_model
+
+ # Tested llama3:8b, gemma2:9b, qwen2:7b, glm4:9b and arcee-ai/arcee-agent:latest; so far qwen2:7b works best
+ llm = Ollama(model=neo4j_model)
+
+ json_example = {'edges': [
+     {
+         'label': 'label 1',
+         'source': 'source 1',
+         'target': 'target 1'},
+     {
+         'label': 'label 1',
+         'source': 'source 1',
+         'target': 'target 1'}
+ ],
+     'nodes': [{'name': 'label 1'},
+               {'name': 'label 2'},
+               {'name': 'label 3'}]
+ }
+
+ __retriever_prompt = f"""
+ 您是一名专门从事知识图谱创建的人工智能专家,目标是根据给定的输入或请求捕获关系。
+ 基于各种形式的用户输入,如段落、电子邮件、文本文件等。
+ 你的任务是根据输入创建一个知识图谱。
+ nodes中每个元素只有一个name参数,name对应的值是一个实体,实体来自输入的词语或短语。
+ edges还必须有一个label参数,其中label是输入中的直接词语或短语,edges中的source和target取自nodes中的name。
+
+ 仅使用JSON进行响应,其格式可以在python中进行jsonify,并直接输入cy.add(data),
+ 您可以参考给定的示例:{json_example}。存储node和edge的数组中,最后一个元素后边不要有逗号,
+ 确保边的目标和源与现有节点匹配。
+ 不要在JSON的上方和下方包含markdown三引号,直接用花括号括起来。
+ """
+
+
+ def generate_graph_info(raw_text: str) -> str | None:
+     """
+     Generate graph info (nodes and edges as JSON) from raw text.
+     :param raw_text: input text to extract entities and relations from
+     :return: the raw LLM response, expected to be a JSON string
+     """
+     messages = [
+         {"role": "system", "content": "你现在扮演信息抽取的角色,要求根据用户输入和AI的回答,正确提取出信息,记得不要对实体进行翻译。"},
+         {"role": "user", "content": raw_text},
+         {"role": "user", "content": __retriever_prompt}
+     ]
+     print("解析中....")
+     for i in range(3):
+         graph_info_result = llm.invoke(messages)
+         if len(graph_info_result) < 10:
+             print("-------", i, "-------------------")
+             continue
+         else:
+             break
+     print(graph_info_result)
+     return graph_info_result
+
+
+ def update_graph(raw_text):
+     # raw_text = request.json.get('text', '')
+     try:
+         result = generate_graph_info(raw_text)
+         if '```' in result:
+             graph_data = eval(result.split('```', 2)[1].replace("json", ''))
+         else:
+             graph_data = eval(str(result))
+         return graph_data
+     except Exception as e:
+         return {'error': f"Error parsing graph data: {str(e)}"}
Neo4j/neo4j_op.py ADDED
@@ -0,0 +1,105 @@
+ from py2neo import Graph, Node, Relationship
+ from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+ class KnowledgeGraph:
+     def __init__(self, uri, user, password):
+         self.graph = Graph(uri, auth=(user, password))
+
+     def parse_data(self, file):
+         if "txt" in file.lower() or "csv" in file.lower():
+             try:
+                 loaders = UnstructuredCSVLoader(file)
+                 data = loaders.load()
+             except:
+                 loaders = TextLoader(file, encoding="utf-8")
+                 data = loaders.load()
+         if ".doc" in file.lower() or ".docx" in file.lower():
+             loaders = UnstructuredWordDocumentLoader(file)
+             data = loaders.load()
+         if "pdf" in file.lower():
+             loaders = UnstructuredPDFLoader(file)
+             data = loaders.load()
+         if ".xlsx" in file.lower():
+             loaders = UnstructuredExcelLoader(file)
+             data = loaders.load()
+         if ".md" in file.lower():
+             loaders = UnstructuredMarkdownLoader(file)
+             data = loaders.load()
+         return data
+
+     # Split the data into chunks
+     def split_files(self, files, chunk_size=500, chunk_overlap=100):
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+         print("开始创建数据库 ....")
+         tmps = []
+         for file in files:
+             data = self.parse_data(file)
+             tmps.extend(data)
+
+         splits = text_splitter.split_documents(tmps)
+
+         return splits
+
+     def create_node(self, label, properties):
+         matcher = self.graph.nodes.match(label, **properties)
+         if matcher.first():
+             return matcher.first()
+         else:
+             node = Node(label, **properties)
+             self.graph.create(node)
+             return node
+
+     def create_relationship(self, label1, properties1, label2, properties2, relationship_type,
+                             relationship_properties={}):
+         node1 = self.create_node(label1, properties1)
+         node2 = self.create_node(label2, properties2)
+
+         matcher = self.graph.match((node1, node2), r_type=relationship_type)
+         for rel in matcher:
+             if all(rel[key] == value for key, value in relationship_properties.items()):
+                 return rel
+
+         relationship = Relationship(node1, relationship_type, node2, **relationship_properties)
+         self.graph.create(relationship)
+         return relationship
+
+     def delete_node(self, label, properties):
+         matcher = self.graph.nodes.match(label, **properties)
+         node = matcher.first()
+         if node:
+             self.graph.delete(node)
+             return True
+         return False
+
+     def update_node(self, label, identifier, updates):
+         matcher = self.graph.nodes.match(label, **identifier)
+         node = matcher.first()
+         if node:
+             for key, value in updates.items():
+                 node[key] = value
+             self.graph.push(node)
+             return node
+         return None
+
+     def find_node(self, label, properties):
+         matcher = self.graph.nodes.match(label, **properties)
+         return list(matcher)
+
+     def create_nodes(self, label, properties_list):
+         nodes = []
+         for properties in properties_list:
+             node = self.create_node(label, properties)
+             nodes.append(node)
+         return nodes
+
+     def create_relationships(self, relationships):
+         created_relationships = []
+         for rel in relationships:
+             label1, properties1, label2, properties2, relationship_type = rel
+             relationship = self.create_relationship(label1, properties1, label2, properties2, relationship_type)
+             created_relationships.append(relationship)
+         return created_relationships
+
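
A minimal usage sketch (illustrative only, not part of this commit): it assumes a Neo4j instance reachable with the credentials from Config/config.py and is run from the project root. The entity names are taken from the sample text in test/graph2neo4j.py later in this commit, which exercises the batch create_nodes/create_relationships path on extracted data.

# Illustrative only: connect with the committed config and create one relationship by hand.
from Config.config import neo4j_host, neo4j_name, neo4j_pwd
from Neo4j.neo4j_op import KnowledgeGraph

kg = KnowledgeGraph(neo4j_host, neo4j_name, neo4j_pwd)
rel = kg.create_relationship("node", {"name": "范冰冰"}, "node", {"name": "还珠格格"}, "参演")
print(rel)
print(kg.find_node("node", {"name": "范冰冰"}))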
Ollama_api/__pycache__/ollama_api.cpython-310.pyc ADDED
Binary file (721 Bytes).
 
Ollama_api/ollama_api.py ADDED
@@ -0,0 +1,21 @@
+ import requests
+ import json
+
+ # Provide helpers that query the local Ollama API for the list of available models
+ def get_llm():
+     response = requests.get(url="http://localhost:11434/api/tags")
+     result = json.loads(response.content)
+     llms = []
+     for llm in result["models"]:
+         if "code" not in llm["name"] and "embed" not in llm["name"]:
+             llms.append(llm["name"])
+     return llms
+
+ def get_embeding_model():
+     response = requests.get(url="http://localhost:11434/api/tags")
+     result = json.loads(response.content)
+     llms = []
+     for llm in result["models"]:
+         if "embed" in llm["name"]:
+             llms.append(llm["name"])
+     return llms
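
A quick sketch (illustrative only, not part of this commit) of calling these helpers; it assumes an Ollama server is listening on localhost:11434 as hard-coded above and is run from the project root.

# Illustrative only: list chat-capable models and embedding models from the local Ollama server.
from Ollama_api.ollama_api import get_llm, get_embeding_model

print("chat models:", get_llm())
print("embedding models:", get_embeding_model())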
embeding/__pycache__/asr_utils.cpython-310.pyc ADDED
Binary file (634 Bytes).
 
embeding/__pycache__/chromadb.cpython-310.pyc ADDED
Binary file (3.91 kB).
 
embeding/__pycache__/elasticsearchStore.cpython-310.pyc ADDED
Binary file (4.18 kB).
 
embeding/__pycache__/faissdb.cpython-310.pyc ADDED
Binary file (4.21 kB).
 
embeding/asr_utils.py ADDED
@@ -0,0 +1,17 @@
+ # coding:utf-8
+ from funasr import AutoModel
+ # paraformer-zh is a multi-functional asr model
+ # use vad, punc, spk or not as you need
+ model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
+                   # spk_model="cam++"
+                   )
+ def get_spk_txt(file):
+     res = model.generate(input=file,
+                          batch_size_s=300,
+                          hotword='魔搭')
+     print(res[0]["text"])
+     fw = "embeding/tmp.txt"
+     f = open(fw, "w", encoding="utf-8")
+     f.write('"context"\n' + res[0]["text"])
+     f.close()
+     return fw
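
A usage sketch (illustrative only, not part of this commit): get_spk_txt transcribes an audio file with FunASR and writes the text to embeding/tmp.txt, whose path it returns, so it must be run from the repository root. The file name below is a hypothetical placeholder.

# Illustrative only: transcribe a local recording; "demo.wav" is a hypothetical path.
from embeding.asr_utils import get_spk_txt

txt_path = get_spk_txt("demo.wav")
print("transcript written to:", txt_path)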
embeding/chromadb.py ADDED
@@ -0,0 +1,134 @@
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.embeddings import OllamaEmbeddings
+ from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from .asr_utils import get_spk_txt
+
+ class ChromaDB():
+     def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Chroma_db/"):
+
+         self.embedding = OllamaEmbeddings(model=embedding)
+         self.persist_directory = persist_directory
+         self.chromadb = Chroma(persist_directory=persist_directory)
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=50)
+
+     def parse_data(self, file):
+         if "txt" in file.lower() or "csv" in file.lower():
+             try:
+                 loaders = UnstructuredCSVLoader(file)
+                 data = loaders.load()
+             except:
+                 loaders = TextLoader(file, encoding="utf-8")
+                 data = loaders.load()
+         if ".doc" in file.lower() or ".docx" in file.lower():
+             loaders = UnstructuredWordDocumentLoader(file)
+             data = loaders.load()
+         if "pdf" in file.lower():
+             loaders = UnstructuredPDFLoader(file)
+             data = loaders.load()
+         if ".xlsx" in file.lower():
+             loaders = UnstructuredExcelLoader(file)
+             data = loaders.load()
+         if ".md" in file.lower():
+             loaders = UnstructuredMarkdownLoader(file)
+             data = loaders.load()
+         if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
+             # Transcribe speech to text
+             fw = get_spk_txt(file)
+             loaders = UnstructuredCSVLoader(fw)
+             data = loaders.load()
+             tmp = []
+             for i in data:
+                 i.metadata["source"] = file
+                 tmp.append(i)
+             data = tmp
+         return data
+
+     # Create a new collection and initialize it
+     def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+         print("开始创建数据库 ....")
+         tmps = []
+         for file in files:
+             data = self.parse_data(file)
+             tmps.extend(data)
+
+         splits = self.text_splitter.split_documents(tmps)
+
+         vectorstore = self.chromadb.from_documents(documents=splits, collection_name=c_name,
+                                                    embedding=self.embedding, persist_directory=self.persist_directory)
+         print("数据块总量:", vectorstore._collection.count())
+
+         return vectorstore
+
+     # Add data to an existing database
+     def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+         print("开始添加文件...")
+         tmps = []
+         for file in files:
+             data = self.parse_data(file)
+             tmps.extend(data)
+
+         splits = self.text_splitter.split_documents(tmps)
+
+         vectorstore = Chroma(persist_directory=self.persist_directory, collection_name=c_name,
+                              embedding_function=self.embedding)
+         vectorstore.add_documents(splits)
+         print("数据块总量:", vectorstore._collection.count())
+
+         return vectorstore
+
+     # Delete specific files from a collection
+     def del_files(self, del_files_name, c_name):
+
+         vectorstore = self.chromadb._client.get_collection(c_name)
+         del_ids = []
+         vec_dict = vectorstore.get()
+         for id, md in zip(vec_dict["ids"], vec_dict["metadatas"]):
+             for dl in del_files_name:
+                 if dl in md["source"]:
+                     del_ids.append(id)
+         vectorstore.delete(ids=del_ids)
+         print("数据块总量:", vectorstore.count())
+
+         return vectorstore
+
+     # Delete a knowledge-base collection
+     def delete_collection(self, c_name):
+
+         self.chromadb._client.delete_collection(c_name)
+
+     # Get the names of all current collections
+     def get_all_collections_name(self):
+         cl_names = []
+
+         test = self.chromadb._client.list_collections()
+         for i in range(len(test)):
+             cl_names.append(test[i].name)
+         return cl_names
+
+     # Get all files contained in a collection
+     def get_collcetion_content_files(self, c_name):
+         vectorstore = self.chromadb._client.get_collection(c_name)
+         c_files = []
+         vec_dict = vectorstore.get()
+         for md in vec_dict["metadatas"]:
+             c_files.append(md["source"])
+         return list(set(c_files))
+
+
+ # if __name__ == "__main__":
+ #     chromadb = ChromaDB()
+ #     c_name = "sss3"
+ #
+ #     print(chromadb.get_all_collections_name())
+ #     chromadb.create_collection(["data/肾内科学.txt", "data/jl.pdf"], c_name=c_name)
+ #     print(chromadb.get_all_collections_name())
+ #     chromadb.add_chroma(["data/儿科学.txt"], c_name=c_name)
+ #     print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
+ #     chromadb.del_files(["data/肾内科学.txt"], c_name=c_name)
+ #     print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
+ #     print(chromadb.get_all_collections_name())
+ #     chromadb.delete_collection(c_name=c_name)
+ #     print(chromadb.get_all_collections_name())
embeding/elasticsearchStore.py ADDED
@@ -0,0 +1,147 @@
+ from elasticsearch import Elasticsearch
+ from langchain_elasticsearch.vectorstores import ElasticsearchStore
+ from langchain_community.embeddings import OllamaEmbeddings
+ from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, \
+     UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from .asr_utils import get_spk_txt
+ import requests
+
+
+ class ElsStore():
+     def __init__(self, embedding="mofanke/acge_text_embedding:latest", es_url="http://localhost:9200",
+                  index_name='test_index'):
+         self.embedding = OllamaEmbeddings(model=embedding)
+         self.es_url = es_url
+         self.elastic_vector_search = ElasticsearchStore(
+             es_url=self.es_url,
+             index_name=index_name,
+             embedding=self.embedding
+         )
+
+     def parse_data(self, file):
+         if "txt" in file.lower() or "csv" in file.lower():
+             try:
+                 loaders = UnstructuredCSVLoader(file)
+                 data = loaders.load()
+             except:
+                 loaders = TextLoader(file, encoding="utf-8")
+                 data = loaders.load()
+         if ".doc" in file.lower() or ".docx" in file.lower():
+             loaders = UnstructuredWordDocumentLoader(file)
+             data = loaders.load()
+         if "pdf" in file.lower():
+             loaders = UnstructuredPDFLoader(file)
+             data = loaders.load()
+         if ".xlsx" in file.lower():
+             loaders = UnstructuredExcelLoader(file)
+             data = loaders.load()
+         if ".md" in file.lower():
+             loaders = UnstructuredMarkdownLoader(file)
+             data = loaders.load()
+         if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
+             # Transcribe speech to text
+             fw = get_spk_txt(file)
+             loaders = UnstructuredCSVLoader(fw)
+             data = loaders.load()
+             tmp = []
+             for i in data:
+                 i.metadata["source"] = file
+                 tmp.append(i)
+             data = tmp
+         return data
+
+     def get_count(self, c_name):
+         # Get the number of chunks stored in the index
+
+         # Initialize the Elasticsearch client
+         es = Elasticsearch([{
+             'host': self.es_url.split(":")[1][2:],
+             'port': int(self.es_url.split(":")[2]),
+             'scheme': 'http'  # protocol to use
+         }])
+
+         # Target index name
+         index_name = c_name
+
+         # Fetch the total document count
+         response = es.count(index=index_name)
+
+         # Return the total document count
+         return response['count']
+
+     # Create a new index and initialize it
+     def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+         print("开始创建数据库 ....")
+         tmps = []
+         for file in files:
+             data = self.parse_data(file)
+             tmps.extend(data)
+
+         splits = self.text_splitter.split_documents(tmps)
+
+         self.elastic_vector_search = ElasticsearchStore.from_documents(
+             documents=splits,
+             embedding=self.embedding,
+             es_url=self.es_url,
+             index_name=c_name,
+         )
+
+         self.elastic_vector_search.client.indices.refresh(index=c_name)
+
+         print("数据块总量:", self.get_count(c_name))
+
+         return self.elastic_vector_search
+
+     # Add data to an existing database
+     def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+         print("开始添加文件...")
+         tmps = []
+         for file in files:
+             data = self.parse_data(file)
+             tmps.extend(data)
+
+         splits = self.text_splitter.split_documents(tmps)
+
+         self.elastic_vector_search = ElasticsearchStore(
+             es_url=self.es_url,
+             index_name=c_name,
+             embedding=self.embedding
+         )
+         self.elastic_vector_search.add_documents(splits)
+         self.elastic_vector_search.client.indices.refresh(index=c_name)
+         print("数据块总量:", self.get_count(c_name))
+
+         return self.elastic_vector_search
+
+     # Delete a knowledge-base collection (index)
+     def delete_collection(self, c_name):
+         url = self.es_url + "/" + c_name
+         # Send the DELETE request
+         response = requests.delete(url)
+
+         # Check the response status code
+         if response.status_code == 200:
+             return f"索引 '{c_name}' 已成功删除。"
+         elif response.status_code == 404:
+             return f"索引 '{c_name}' 不存在。"
+         else:
+             return f"删除索引时出错: {response.status_code}, {response.text}"
+
+     # Get all current index names
+     def get_all_collections_name(self):
+         indices = self.elastic_vector_search.client.indices.get_alias()
+         index_names = list(indices.keys())
+
+         return index_names
+
+     # Get all files in a collection (not implemented for Elasticsearch)
+     def get_collcetion_content_files(self, c_name):
+         return []
+
+     # Delete specific files from a collection (not implemented for Elasticsearch)
+     def del_files(self, del_files_name, c_name):
+         return None
+
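
Unlike chromadb.py above and faissdb.py below, this module ships no commented-out __main__ example, so here is a hedged sketch (illustrative only, not part of this commit) that mirrors their usage pattern; it assumes Elasticsearch at http://localhost:9200, the acge_text_embedding model pulled in Ollama, and hypothetical data files.

# Illustrative only: build an index from local files, list indices, then drop it.
from embeding.elasticsearchStore import ElsStore

es_store = ElsStore()
es_store.create_collection(["data/jl.txt", "data/jl.pdf"], c_name="sss3")  # hypothetical files
print(es_store.get_all_collections_name())
print(es_store.delete_collection("sss3"))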
embeding/faissdb.py ADDED
@@ -0,0 +1,138 @@
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import OllamaEmbeddings
+ from langchain_community.document_loaders import TextLoader, UnstructuredCSVLoader, UnstructuredPDFLoader, UnstructuredWordDocumentLoader, UnstructuredExcelLoader, UnstructuredMarkdownLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ import shutil
+ import os
+ from .asr_utils import get_spk_txt
+
+
+ class FaissDB():
+     def __init__(self, embedding="mofanke/acge_text_embedding:latest", persist_directory="./Faiss_db/"):
+
+         self.embedding = OllamaEmbeddings(model=embedding)
+         self.persist_directory = persist_directory
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50, add_start_index=True)
+
+     def parse_data(self, file):
+         if "txt" in file.lower() or "csv" in file.lower():
+             try:
+                 loaders = UnstructuredCSVLoader(file)
+                 data = loaders.load()
+             except:
+                 loaders = TextLoader(file, encoding="utf-8")
+                 data = loaders.load()
+         if ".doc" in file.lower() or ".docx" in file.lower():
+             loaders = UnstructuredWordDocumentLoader(file)
+             data = loaders.load()
+         if "pdf" in file.lower():
+             loaders = UnstructuredPDFLoader(file)
+             data = loaders.load()
+         if ".xlsx" in file.lower():
+             loaders = UnstructuredExcelLoader(file)
+             data = loaders.load()
+         if ".md" in file.lower():
+             loaders = UnstructuredMarkdownLoader(file)
+             data = loaders.load()
+         if "mp3" in file.lower() or "mp4" in file.lower() or "wav" in file.lower():
+             # Transcribe speech to text
+             fw = get_spk_txt(file)
+             loaders = UnstructuredCSVLoader(fw)
+             data = loaders.load()
+             tmp = []
+             for i in data:
+                 i.metadata["source"] = file
+                 tmp.append(i)
+             data = tmp
+         return data
+
+     # Create a new collection and initialize it
+     def create_collection(self, files, c_name, chunk_size=200, chunk_overlap=50):
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+         print("开始创建数据库 ....")
+         tmps = []
+         for file in files:
+             data = self.parse_data(file)
+             tmps.extend(data)
+
+         splits = self.text_splitter.split_documents(tmps)
+
+         vectorstore = FAISS.from_documents(documents=splits,
+                                            embedding=self.embedding)
+         vectorstore.save_local(self.persist_directory + c_name)
+         print("数据块总量:", vectorstore.index.ntotal)
+
+         return vectorstore
+
+     # Add data to an existing database
+     def add_chroma(self, files, c_name, chunk_size=200, chunk_overlap=50):
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+         print("开始添加文件...")
+         tmps = []
+         for file in files:
+             data = self.parse_data(file)
+             tmps.extend(data)
+
+         splits = self.text_splitter.split_documents(tmps)
+
+         vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
+                                        allow_dangerous_deserialization=True)
+         vectorstore.add_documents(documents=splits)
+         vectorstore.save_local(self.persist_directory + c_name)
+         print("数据块总量:", vectorstore.index.ntotal)
+
+         return vectorstore
+
+     # Delete specific files from a collection
+     def del_files(self, del_files_name, c_name):
+
+         vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
+                                        allow_dangerous_deserialization=True)
+         del_ids = []
+         vec_dict = vectorstore.docstore._dict
+         for id, md in vec_dict.items():
+             for dl in del_files_name:
+                 if dl in md.metadata["source"]:
+                     del_ids.append(id)
+         vectorstore.delete(ids=del_ids)
+         vectorstore.save_local(self.persist_directory + c_name)
+         print("数据块总量:", vectorstore.index.ntotal)
+
+         return vectorstore
+
+     # Delete a knowledge-base collection
+     def delete_collection(self, c_name):
+         shutil.rmtree(self.persist_directory + c_name)
+
+     # Get the names of all current collections
+     def get_all_collections_name(self):
+         cl_names = [i for i in os.listdir(self.persist_directory) if os.path.isdir(self.persist_directory + i)]
+
+         return cl_names
+
+     # Get all files contained in a collection
+     def get_collcetion_content_files(self, c_name):
+         vectorstore = FAISS.load_local(folder_path=self.persist_directory + c_name, embeddings=self.embedding,
+                                        allow_dangerous_deserialization=True)
+         c_files = []
+         vec_dict = vectorstore.docstore._dict
+         for _, md in vec_dict.items():
+             c_files.append(md.metadata["source"])
+
+         return list(set(c_files))
+
+
+ # if __name__ == "__main__":
+ #     chromadb = FaissDB()
+ #     c_name = "sss3"
+ #
+ #     print(chromadb.get_all_collections_name())
+ #     chromadb.create_collection(["data/jl.txt", "data/jl.pdf"], c_name=c_name)
+ #     print(chromadb.get_all_collections_name())
+ #     chromadb.add_chroma(["data/tmp.txt"], c_name=c_name)
+ #     print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
+ #     chromadb.del_files(["data/tmp.txt"], c_name=c_name)
+ #     print(c_name, "包含的文件:", chromadb.get_collcetion_content_files(c_name))
+ #     print(chromadb.get_all_collections_name())
+ #     chromadb.delete_collection(c_name=c_name)
+ #     print(chromadb.get_all_collections_name())
embeding/tmp.txt ADDED
@@ -0,0 +1,2 @@
+ "context"
+ 你是不是觉得自己说话的声音直来直去呢?现在告诉你一个主持人吐字的小秘密,那就是每个字在口腔当中像是翻跟头一样打一圈再出来。比如说故人西辞黄鹤楼,而不是故人西辞黄鹤楼。再比如说乌衣巷口夕阳斜,而不是乌衣巷口夕阳斜,你也试试看抖音。
img/graph-tool.png ADDED

Git LFS Details

  • SHA256: 8aa90d4cba907a57c8d5cc5e2c193240955c1c9cee23dcffc8bbd597616f6bed
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
img/readme.txt ADDED
@@ -0,0 +1 @@
+ 1
img/zhu.png ADDED

Git LFS Details

  • SHA256: 71f5efce94f123211ea1b7392e9644f953ae158ee10a3c4d28f522a23b9387b8
  • Pointer size: 131 Bytes
  • Size of remote file: 630 kB
img/zhuye.png ADDED
img/复杂方式.png ADDED
img/微信图片_20240524180648.jpg ADDED
rag/__init__.py ADDED
File without changes
rag/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (132 Bytes).
 
rag/__pycache__/config.cpython-310.pyc ADDED
Binary file (364 Bytes).
 
rag/__pycache__/rag_class.cpython-310.pyc ADDED
Binary file (5.39 kB).
 
rag/__pycache__/rerank.cpython-310.pyc ADDED
Binary file (878 Bytes).
 
rag/__pycache__/rerank.cpython-39.pyc ADDED
Binary file (869 Bytes).
 
rag/__pycache__/rerank_code.cpython-310.pyc ADDED
Binary file (883 Bytes).
 
rag/rag_class.py ADDED
@@ -0,0 +1,169 @@
+ from langchain_community.vectorstores import Chroma, FAISS
+ from langchain_community.llms import Ollama
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_community.embeddings import OllamaEmbeddings
+ from langchain_core.runnables import RunnablePassthrough
+ from operator import itemgetter
+ from langchain.prompts import ChatPromptTemplate
+ from rerank_code import rerank_topn
+ from Config.config import VECTOR_DB, DB_directory
+ from langchain_elasticsearch.vectorstores import ElasticsearchStore
+
+
+ class RAG_class:
+     def __init__(self, model="qwen2:7b", embed="milkey/dmeta-embedding-zh:f16", c_name="sss1",
+                  persist_directory="E:/pycode/jupyter_code/langGraph/sss2/chroma.sqlite3/", es_url="http://localhost:9200"):
+         template = """
+         根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案,
+
+         参考内容为:{context}
+
+         问题: {question}
+         """
+
+         self.prompts = ChatPromptTemplate.from_template(template)
+
+         # Use question expansion plus recursive answering to reach the final answer
+         template1 = """你是一个乐于助人的助手,可以生成与输入问题相关的多个子问题。
+         目标是将输入分解为一组可以单独回答的子问题/子问题。
+         生成多个与以下内容相关的搜索查询:{question}
+         输出4个相关问题,以换行符隔开:"""
+         self.prompt_questions = ChatPromptTemplate.from_template(template1)
+
+         # Prompt that combines the question with accumulated question-answer pairs
+         template2 = """
+         以下是您需要回答的问题:
+
+         \n--\n {question} \n---\n
+
+         以下是任何可用的背景问答对:
+
+         \n--\n {q_a_pairs} \n---\n
+
+         以下是与该问题相关的其他上下文:
+
+         \n--\n {context} \n---\n
+
+         使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是:
+         """
+         self.decomposition_prompt = ChatPromptTemplate.from_template(template2)
+
+         self.llm = Ollama(model=model)
+         self.embeding = OllamaEmbeddings(model=embed)
+         try:
+             if VECTOR_DB == 1:
+                 self.vectstore = Chroma(embedding_function=self.embeding, collection_name=c_name,
+                                         persist_directory=persist_directory)
+             elif VECTOR_DB == 2:
+                 self.vectstore = FAISS.load_local(folder_path=persist_directory + c_name, embeddings=self.embeding,
+                                                   allow_dangerous_deserialization=True)
+             elif VECTOR_DB == 3:
+                 self.vectstore = ElasticsearchStore(
+                     es_url=es_url,
+                     index_name=c_name,
+                     embedding=self.embeding
+                 )
+             self.retriever = self.vectstore.as_retriever()
+         except Exception as e:
+             print("仅模型时无需加载数据库", e)
+
+     # Post-processing
+     def format_docs(self, docs):
+         return "\n\n".join(doc.page_content for doc in docs)
+
+     # Traditional retrieval: recall with the single question, then let the LLM summarize the answer
+     def simple_chain(self, question):
+         _chain = (
+             {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
+             | self.prompts
+             | self.llm
+             | StrOutputParser()
+         )
+         answer = _chain.invoke({"question": question})
+         return answer
+
+     # Retrieve a larger candidate set, rerank it, then answer from the top documents
+     def rerank_chain(self, question):
+         retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
+         docs = retriever.invoke(question)
+         docs = rerank_topn(question, docs, N=5)
+         _chain = (
+             self.prompts
+             | self.llm
+             | StrOutputParser()
+         )
+         answer = _chain.invoke({"context": self.format_docs(docs), "question": question})
+         return answer
+
+     def format_qa_pairs(self, question, answer):
+         formatted_string = ""
+         formatted_string += f"Question: {question}\nAnswer:{answer}\n\n"
+         return formatted_string
+
+     # Get the expanded sub-questions for a question
+     def decomposition_chain(self, question):
+         _chain = (
+             {"question": RunnablePassthrough()}
+             | self.prompt_questions
+             | self.llm
+             | StrOutputParser()
+             | (lambda x: x.split("\n"))
+         )
+
+         questions = _chain.invoke({"question": question}) + [question]
+
+         return questions
+
+     # Recursive multi-question retrieval: after each round, the question and its answer are added
+     # to the reference Q&A pairs used for the next retrieval with the next question
+     def rag_chain(self, questions):
+         q_a_pairs = ""
+         for q in questions:
+             _chain = (
+                 {"context": itemgetter("question") | self.retriever,
+                  "question": itemgetter("question"),
+                  "q_a_pairs": itemgetter("q_a_pairs")
+                  }
+                 | self.decomposition_prompt
+                 | self.llm
+                 | StrOutputParser()
+             )
+
+             answer = _chain.invoke({"question": q, "q_a_pairs": q_a_pairs})
+             q_a_pair = self.format_qa_pairs(q, answer)
+             q_a_pairs = q_a_pairs + "\n----\n" + q_a_pair
+         return answer
+
+     # Format the chat history into a single string
+     def format_chat_history(self, history):
+         formatted_history = ""
+         for role, content in history:
+             formatted_history += f"{role}: {content}\n"
+         return formatted_history
+
+     # Multi-turn chat backed directly by the Ollama model, without using the knowledge base
+     def mult_chat(self, chat_history):
+         # Format the chat history
+         formatted_history = self.format_chat_history(chat_history)
+
+         # Invoke the model to generate a reply
+         response = self.llm.invoke(formatted_history)
+         return response
+
+
+ # if __name__ == "__main__":
+ #     rag = RAG_class(model="deepseek-r1:14b")
+ #     question = "人卫社官网网址是?"
+ #     questions = rag.decomposition_chain(question)
+ #     print(questions)
+ #     answer = rag.rag_chain(questions)
+ #     print(answer)
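
The commented-out __main__ above exercises the question-decomposition path; as a complementary sketch (illustrative only, not part of this commit), the single-question chains can be called as below. It assumes the qwen2:7b and milkey/dmeta-embedding-zh:f16 models are pulled in Ollama, the local bge-reranker-large path used by rag/rerank_code.py exists, VECTOR_DB=2 as in Config/config.py, and a FAISS index under ./Faiss_db/sss1 built with the same embedding model; the sample question is taken from the commented example above.

# Illustrative only: plain rerank-then-answer chain and a knowledge-base-free multi-turn chat.
import sys
sys.path.append("rag")  # rag_class.py imports rerank_code without the package prefix
from rag.rag_class import RAG_class

rag = RAG_class(c_name="sss1", persist_directory="./Faiss_db/")
print(rag.rerank_chain("人卫社官网网址是?"))
print(rag.mult_chat([("user", "你好")]))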
rag/rerank_code.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained('E:\\model\\bge-reranker-large')
+ model = AutoModelForSequenceClassification.from_pretrained('E:\\model\\bge-reranker-large')
+ model.eval()
+
+
+ def rerank_topn(question, docs, N=5):
+     pairs = []
+     for i in docs:
+         pairs.append([question, i.page_content])
+
+     with torch.no_grad():
+         inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
+         scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
+         # Indices of the top-N documents, sorted by descending relevance score
+         scores = scores.argsort().numpy()[::-1][:N]
+     bk = []
+     for i in scores:
+         bk.append(docs[i])
+     return bk
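
rerank_chain in rag/rag_class.py above feeds the top-10 retrieved documents through rerank_topn(question, docs, N=5). A standalone sketch (illustrative only, not part of this commit) might look like the following; it assumes the E:\model\bge-reranker-large checkpoint referenced above is available locally, and the document contents are hypothetical placeholders.

# Illustrative only: rerank a few hand-made LangChain Documents against a query.
from langchain_core.documents import Document
from rag.rerank_code import rerank_topn

docs = [
    Document(page_content="placeholder passage that mentions the query topic"),
    Document(page_content="unrelated placeholder passage"),
    Document(page_content="another unrelated placeholder passage"),
]
top2 = rerank_topn("query about the first passage's topic", docs, N=2)
print([d.page_content for d in top2])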
test/__init__.py ADDED
File without changes
test/graph2neo4j.py ADDED
@@ -0,0 +1,25 @@
+ # -*- coding: utf-8 -*-
+ import sys
+ sys.path.append(r"..//")
+ from Neo4j.neo4j_op import KnowledgeGraph
+ from Neo4j.graph_extract import update_graph
+ from Config.config import neo4j_host, neo4j_name, neo4j_pwd
+
+
+ kg = KnowledgeGraph(neo4j_host, neo4j_name, neo4j_pwd)
+
+
+ if __name__ == "__main__":
+
+     text = """范冰冰,1981年9月16日生于山东青岛,毕业于上海师范大学谢晋影视艺术学院,中国女演员,歌手。
+     1998年参演电视剧《还珠格格》成名。2004年主演电影《手机》获得第27届大众电影百花奖最佳女演员奖。"""
+     res = update_graph(text)
+     # Batch-create the nodes
+     nodes = kg.create_nodes("node", res["nodes"])
+     print(nodes)
+     # Batch-create the relationships
+     relationships = kg.create_relationships([
+         ("node", {"name": edge["source"]}, "node", {"name": edge["target"]}, edge["label"]) for edge in res["edges"]
+     ])
+     print(relationships)