bytedancerneat commited on
Commit
1de832e
·
verified ·
1 Parent(s): b263d05

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +120 -120
retriever.py CHANGED
@@ -1,121 +1,121 @@
1
- import pandas as pd
2
- import json
3
- import sys
4
- import os
5
- from collections import defaultdict
6
- from util.vector_base import EmbeddingFunction, get_or_create_vector_base
7
- from util.Embeddings import TextEmb3LargeEmbedding
8
- from langchain_core.documents import Document
9
- from FlagEmbedding import FlagReranker
10
- import time
11
- from bm25s import BM25, tokenize
12
- import contextlib
13
- import io
14
- from tqdm import tqdm
15
-
16
- def rrf(rankings, k = 60):
17
- res = 0
18
- for r in rankings:
19
- res += 1 / (r + k)
20
- return res
21
-
22
- def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
23
- final_result = []
24
- if not if_split_po:
25
- final_result = multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
26
- else:
27
- for po in PO:
28
- po_result = multiretriever(requirement, [po], safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
29
- for safeguard in po_result:
30
- final_result.append(safeguard)
31
- return final_result
32
-
33
- def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
34
- """
35
- requirements_dict: [
36
- requirement: {
37
- "PO": [],
38
- "safeguard": []
39
- }
40
- ]
41
- """
42
- candidate_safeguards = []
43
- po_list = [po.lower().rstrip() for po in PO if po]
44
- if "young users" in po_list and len(po_list) == 1:
45
- return []
46
- candidate_safeguards = safeguard_vector_store.get(where={"po": {"$in": po_list}})
47
- safeguard_dict, safeguard_content = {}, []
48
- for id, content, metadata in zip(candidate_safeguards['ids'], candidate_safeguards['documents'], candidate_safeguards['metadatas']):
49
- safeguard_dict[content] = {
50
- "metadata": metadata,
51
- "rank": [],
52
- "rrf_score": 0
53
- }
54
- safeguard_content.append(content)
55
-
56
- # Reranker
57
- if using_reranker:
58
- content_pairs, reranking_rank, reranking_results = [], [], []
59
- for safeguard in safeguard_content:
60
- content_pairs.append([requirement, safeguard])
61
- safeguard_rerank_scores = reranker_model.compute_score(content_pairs)
62
- for content_pair, score in zip(content_pairs, safeguard_rerank_scores):
63
- reranking_rank.append((content_pair[1], score))
64
- reranking_results = sorted(reranking_rank, key=lambda x: x[1], reverse=True)
65
- for safeguard, score in reranking_results:
66
- safeguard_dict[safeguard]['rank'].append(reranking_results.index((safeguard, score)) + 1)
67
-
68
- # BM25
69
- if using_BM25:
70
- with contextlib.redirect_stdout(io.StringIO()):
71
- bm25_retriever = BM25(corpus=safeguard_content)
72
- bm25_retriever.index(tokenize(safeguard_content))
73
- bm25_results, scores = bm25_retriever.retrieve(tokenize(requirement), k = len(safeguard_content))
74
- bm25_retrieval_rank = 1
75
- for safeguard in bm25_results[0]:
76
- safeguard_dict[safeguard]['rank'].append(bm25_retrieval_rank)
77
- bm25_retrieval_rank += 1
78
-
79
- # chroma retrieval
80
- if using_chroma:
81
- retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(query=requirement, k=len(candidate_safeguards['ids']), filter={"po": {"$in": po_list}})
82
- retrieval_rank = 1
83
- for safeguard in retrieved_safeguards:
84
- safeguard_dict[safeguard[0].page_content]['rank'].append(retrieval_rank)
85
- retrieval_rank += 1
86
-
87
- final_result = []
88
- for safeguard in safeguard_content:
89
- safeguard_dict[safeguard]['rrf_score'] = rrf(safeguard_dict[safeguard]['rank'])
90
- final_result.append((safeguard_dict[safeguard]['rrf_score'], safeguard_dict[safeguard]['metadata']['safeguard_number'], safeguard, safeguard_dict[safeguard]['metadata']['po']))
91
- final_result.sort(key=lambda x: x[0], reverse=True)
92
-
93
- # top k
94
- topk_final_result = final_result[:k]
95
-
96
- return topk_final_result
97
-
98
- if __name__=="__main__":
99
- embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
100
- embedding = EmbeddingFunction(embeddingmodel)
101
- safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
102
- reranker_model = FlagReranker(
103
- '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
104
- use_fp16=True,
105
- devices=["cpu"],
106
- )
107
- requirement = """
108
- Data Minimization Consent for incompatible purposes: Require consent for additional use of personal information not reasonably necessary to or incompatible with original purpose disclosure.
109
- """
110
- PO = ["Data Minimization & Purpose Limitation", "Transparency"]
111
- final_result = retriever(
112
- requirement,
113
- PO,
114
- safeguard_vector_store,
115
- reranker_model,
116
- using_reranker=True,
117
- using_BM25=False,
118
- using_chroma=True,
119
- k=10
120
- )
121
  print(final_result)
 
1
+ import pandas as pd
2
+ import json
3
+ import sys
4
+ import os
5
+ from collections import defaultdict
6
+ from util.vector_base import EmbeddingFunction, get_or_create_vector_base
7
+ from util.Embeddings import TextEmb3LargeEmbedding
8
+ from langchain_core.documents import Document
9
+ from FlagEmbedding import FlagReranker
10
+ import time
11
+ # from bm25s import BM25, tokenize
12
+ import contextlib
13
+ import io
14
+ from tqdm import tqdm
15
+
16
+ def rrf(rankings, k = 60):
17
+ res = 0
18
+ for r in rankings:
19
+ res += 1 / (r + k)
20
+ return res
21
+
22
+ def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
23
+ final_result = []
24
+ if not if_split_po:
25
+ final_result = multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
26
+ else:
27
+ for po in PO:
28
+ po_result = multiretriever(requirement, [po], safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
29
+ for safeguard in po_result:
30
+ final_result.append(safeguard)
31
+ return final_result
32
+
33
+ def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
34
+ """
35
+ requirements_dict: [
36
+ requirement: {
37
+ "PO": [],
38
+ "safeguard": []
39
+ }
40
+ ]
41
+ """
42
+ candidate_safeguards = []
43
+ po_list = [po.lower().rstrip() for po in PO if po]
44
+ if "young users" in po_list and len(po_list) == 1:
45
+ return []
46
+ candidate_safeguards = safeguard_vector_store.get(where={"po": {"$in": po_list}})
47
+ safeguard_dict, safeguard_content = {}, []
48
+ for id, content, metadata in zip(candidate_safeguards['ids'], candidate_safeguards['documents'], candidate_safeguards['metadatas']):
49
+ safeguard_dict[content] = {
50
+ "metadata": metadata,
51
+ "rank": [],
52
+ "rrf_score": 0
53
+ }
54
+ safeguard_content.append(content)
55
+
56
+ # Reranker
57
+ if using_reranker:
58
+ content_pairs, reranking_rank, reranking_results = [], [], []
59
+ for safeguard in safeguard_content:
60
+ content_pairs.append([requirement, safeguard])
61
+ safeguard_rerank_scores = reranker_model.compute_score(content_pairs)
62
+ for content_pair, score in zip(content_pairs, safeguard_rerank_scores):
63
+ reranking_rank.append((content_pair[1], score))
64
+ reranking_results = sorted(reranking_rank, key=lambda x: x[1], reverse=True)
65
+ for safeguard, score in reranking_results:
66
+ safeguard_dict[safeguard]['rank'].append(reranking_results.index((safeguard, score)) + 1)
67
+
68
+ # BM25
69
+ if using_BM25:
70
+ with contextlib.redirect_stdout(io.StringIO()):
71
+ bm25_retriever = BM25(corpus=safeguard_content)
72
+ bm25_retriever.index(tokenize(safeguard_content))
73
+ bm25_results, scores = bm25_retriever.retrieve(tokenize(requirement), k = len(safeguard_content))
74
+ bm25_retrieval_rank = 1
75
+ for safeguard in bm25_results[0]:
76
+ safeguard_dict[safeguard]['rank'].append(bm25_retrieval_rank)
77
+ bm25_retrieval_rank += 1
78
+
79
+ # chroma retrieval
80
+ if using_chroma:
81
+ retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(query=requirement, k=len(candidate_safeguards['ids']), filter={"po": {"$in": po_list}})
82
+ retrieval_rank = 1
83
+ for safeguard in retrieved_safeguards:
84
+ safeguard_dict[safeguard[0].page_content]['rank'].append(retrieval_rank)
85
+ retrieval_rank += 1
86
+
87
+ final_result = []
88
+ for safeguard in safeguard_content:
89
+ safeguard_dict[safeguard]['rrf_score'] = rrf(safeguard_dict[safeguard]['rank'])
90
+ final_result.append((safeguard_dict[safeguard]['rrf_score'], safeguard_dict[safeguard]['metadata']['safeguard_number'], safeguard, safeguard_dict[safeguard]['metadata']['po']))
91
+ final_result.sort(key=lambda x: x[0], reverse=True)
92
+
93
+ # top k
94
+ topk_final_result = final_result[:k]
95
+
96
+ return topk_final_result
97
+
98
+ if __name__=="__main__":
99
+ embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
100
+ embedding = EmbeddingFunction(embeddingmodel)
101
+ safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
102
+ reranker_model = FlagReranker(
103
+ '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
104
+ use_fp16=True,
105
+ devices=["cpu"],
106
+ )
107
+ requirement = """
108
+ Data Minimization Consent for incompatible purposes: Require consent for additional use of personal information not reasonably necessary to or incompatible with original purpose disclosure.
109
+ """
110
+ PO = ["Data Minimization & Purpose Limitation", "Transparency"]
111
+ final_result = retriever(
112
+ requirement,
113
+ PO,
114
+ safeguard_vector_store,
115
+ reranker_model,
116
+ using_reranker=True,
117
+ using_BM25=False,
118
+ using_chroma=True,
119
+ k=10
120
+ )
121
  print(final_result)