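"""Evaluate embedding models on the LoCoV1 long-context retrieval benchmark.

For each LoCoV1 subset the script encodes queries and passages, retrieves
candidates with a FAISS inner-product index, and reports NDCG@10. It supports
plain SentenceTransformer checkpoints as well as the multi-vector dewey model
(infgrad/dewey_en_beta).
"""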
import json

import pandas as pd
import torch
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # set before importing sentence_transformers so downloads use the mirror
from sentence_transformers import SentenceTransformer
import tqdm
import numpy as np
import faiss
from sklearn.metrics import ndcg_score
from os.path import join
from sklearn.preprocessing import normalize
from transformers import AutoTokenizer, AutoModel

faiss.omp_set_num_threads(16)


def find_topk_by_vecs(source_vecs: np.ndarray, target_vecs: np.ndarray, topk: int):
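    """Return the top-k most similar target vectors for each source vector.

    FAISS IndexFlatIP scores by inner product, which equals cosine similarity
    for L2-normalized inputs (FAISS expects float32 arrays). Returns
    (indices, scores), each of shape (len(source_vecs), topk).

    Example with hypothetical shapes:
        >>> q = np.random.rand(4, 8).astype(np.float32)
        >>> p = np.random.rand(100, 8).astype(np.float32)
        >>> ids, scores = find_topk_by_vecs(q, p, topk=10)
        >>> ids.shape
        (4, 10)
    """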
    if topk > len(target_vecs):
        topk = len(target_vecs)
    faiss_index = faiss.IndexFlatIP(target_vecs.shape[1])
    faiss_index.add(target_vecs)

    res_distance, res_index = faiss_index.search(source_vecs, topk)
    return res_index, res_distance


def get_loco_path_info(q_dir, d_dir):
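    """Pair each query .jsonl file in q_dir with its document file in d_dir.

    Every document file must have a query file of the same name. Returns a
    sorted list of [dataset_name, file_name, query_path, document_path].
    """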
    names = []
    for name in sorted(os.listdir(q_dir)):
        if name.endswith(".jsonl"):
            names.append(name)
    for name in os.listdir(d_dir):
        if name.endswith(".jsonl"):
            assert name in names
    infos = []
    for name in names:
        infos.append(["LOCO-V1", name, join(q_dir, name), join(d_dir, name)])
    infos.sort(key=lambda x: x[1])
    return infos


def get_loco_data(q_path, d_path):
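    """Load one LoCoV1 subset from its query and document .jsonl files.

    Returns (query_list, passage_list, query2id_list), where query2id_list
    maps each query string to the indices of its relevant passages in
    passage_list. Empty queries/passages and duplicates are dropped.
    """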
    passage_list, query2passage_list = [], {}

    original_doc_id2doc = {}

    with open(d_path, "r", encoding="utf8") as fr:
        for line in fr:
            item = json.loads(line)
            if item["passage"].strip():
                original_doc_id2doc[item["pid"]] = item["passage"].strip()
                passage_list.append(item["passage"].strip())

    with open(q_path, "r", encoding="utf8") as fr:
        for line in fr:
            item = json.loads(line)
            if item["query"].strip():
                query2passage_list[item["query"].strip()] = [
                    original_doc_id2doc[answer_pid]
                    for answer_pid in item["answer_pids"]
                    if answer_pid in original_doc_id2doc
                ]
    query2passage_list = {k: list(set(v)) for k, v in query2passage_list.items() if list(set(v))}
    passage_list = list(set(passage_list))
    passage2id = {passage: idx for idx, passage in enumerate(passage_list)}
    query2id_list = {k: list(set([passage2id[i] for i in v])) for k, v in query2passage_list.items()}
    query_list = list(query2id_list.keys())
    return query_list, passage_list, query2id_list


def get_ndcg_score(query_list, passage_list, query2passage_id_list, topk=10, error_data_save_path: str = None):
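    """Encode queries and passages, retrieve with FAISS, and return NDCG@topk.

    Relies on the module-level globals set in __main__ (model, batch_size,
    chunk_size, chunk_overlap, max_seq_length, multi_vec_strategy). If
    error_data_save_path is given, queries whose positives all miss the top-k
    are written to an Excel file together with their top-1 prediction.
    """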
    chunk_id2passage_id = {}
    q_vecs = model.encode(
        sentences=query_list,
        batch_size=batch_size,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        max_seq_length=max_seq_length,
        is_q=True,
    )
    p_vecs = model.encode(
        sentences=passage_list,
        batch_size=batch_size,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        max_seq_length=max_seq_length,
        is_q=False,
    )
    # Build the padded label matrix: one row per query holding its positive
    # passage ids, padded to max_doc columns by cycling the id list.
    query_id_list = [query2passage_id_list[query] for query in query_list]
    max_doc = max(len(id_list) for id_list in query_id_list)

    labels = np.array([(id_list * max_doc)[:max_doc] for id_list in query_id_list])
    # The dewey encoder returns a list with one (n_vectors, dim) matrix per
    # passage; reduce each matrix according to multi_vec_strategy.
    if isinstance(p_vecs, list):
        for idx, vec in enumerate(p_vecs):
            if multi_vec_strategy == "full_text":
                # keep only the full-text vector (row 1), re-normalized
                p_vecs[idx] = normalize(np.mean(vec[1:2, :], axis=0, keepdims=True), axis=1)
            elif multi_vec_strategy == "full_text+chunks":
                # mean-pool the two full-text vectors (rows 0-1) and stack the
                # first n_chunk chunk vectors after them
                n_chunk = (vec.shape[0] - 2) // 2
                if n_chunk > 0:
                    p_vecs[idx] = np.vstack(
                        (
                            normalize(np.mean(vec[:2, :], axis=0, keepdims=True), axis=1),
                            vec[2:2 + n_chunk, :],
                        )
                    )
                else:
                    p_vecs[idx] = normalize(np.mean(vec[:2, :], axis=0, keepdims=True), axis=1)
        # A passage may now span several stacked rows; record which row
        # (chunk id) belongs to which passage so that retrieval results can
        # be collapsed back to passage ids below.
        offset = 0
        for passage_id, vec in enumerate(p_vecs):
            for _ in range(vec.shape[0]):
                chunk_id2passage_id[offset] = passage_id
                offset += 1
        p_vecs = np.vstack(p_vecs)

    if isinstance(q_vecs, list):
        # queries are always reduced to the mean of their two full-text vectors
        for idx, vec in enumerate(q_vecs):
            q_vecs[idx] = normalize(np.mean(vec[0:2, :], axis=0, keepdims=True), axis=1)
        q_vecs = np.vstack(q_vecs)
    print("q_vecs.shape and dtype", q_vecs.shape, q_vecs.dtype)
    print("p_vecs.shape and dtype", p_vecs.shape, p_vecs.dtype)
    # Retrieve topk * 100 candidates so that, after the chunk-to-passage
    # deduplication below, at least topk unique passages remain per query.
    topk_index, topk_scores = find_topk_by_vecs(q_vecs, p_vecs, topk * 100)
    # With multi-vector passages, several chunks of the same passage can be
    # retrieved; collapse chunk ids to passage ids. Candidates are sorted by
    # score, so the first hit per passage is its best one.
    if chunk_id2passage_id:
        new_topk_index, new_topk_scores = [], []
        for chunk_ids, chunk_scores in tqdm.tqdm(zip(topk_index, topk_scores),
                                                 desc="modify topk_index and topk_scores", disable=True):
            # process one query (one row) at a time
            row_ids, row_scores, passage_id_set = [], [], set()
            for idx, chunk_id in enumerate(chunk_ids):
                passage_id = chunk_id2passage_id[chunk_id]
                if passage_id not in passage_id_set:
                    passage_id_set.add(passage_id)
                    row_ids.append(passage_id)
                    row_scores.append(chunk_scores[idx])
            new_topk_index.append(row_ids[:topk])
            new_topk_scores.append(row_scores[:topk])
        topk_index = np.array(new_topk_index)
        topk_scores = np.array(new_topk_scores)
    topk_index, topk_scores = topk_index[:, :topk], topk_scores[:, :topk]
    # A retrieved passage is relevant iff it matches any positive id of the
    # query (binary relevance), so OR the match masks over all positives.
    is_match = (topk_index == labels[:, :1])
    for idx in range(1, max_doc):
        is_match = is_match | (topk_index == labels[:, idx:idx + 1])

    print("is_match.shape", is_match.shape)
    # recall@topk, if needed: is_match.sum(axis=1).astype(bool).mean()
    ndcg = ndcg_score(is_match.astype(dtype=np.float32), topk_scores)

    if error_data_save_path:
        in_top_k = is_match.sum(axis=1).astype(bool)
        err_data = []
        for idx, pred_res in enumerate(in_top_k):
            if not pred_res:
                query = query_list[idx]
                label_doc = passage_list[query2passage_id_list[query][0]]
                pred_doc = passage_list[topk_index[idx][0]]
                err_data.append([query, label_doc, pred_doc])
        pd.DataFrame(err_data, columns=["Query", "Label", "Pred"]).to_excel(error_data_save_path, index=False)
    return float(ndcg)


class ModelWrapper:
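    """Uniform encode() wrapper around either the dewey multi-vector model
    (loaded with AutoModel and trust_remote_code) or a SentenceTransformer
    checkpoint encoded through a multi-process pool."""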
    def __init__(self, model_dir, model_type, max_seq_length):
        assert model_type in ["dewey", "sentence_transformer"]
        self.model_type = model_type
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if model_type == "dewey":
            self.model = AutoModel.from_pretrained(
                model_dir,
                attn_implementation="flash_attention_2",
                trust_remote_code=True,
            ).cuda().bfloat16().eval()
            self.model.tokenizer = self.tokenizer
        else:
            self.model = SentenceTransformer(
                model_dir,
                trust_remote_code=True,
                device="cpu",
                model_kwargs={
                    "torch_dtype": torch.bfloat16,  # or torch.float16 on GPUs without bf16 support
                    "attn_implementation": "flash_attention_2"
                },
            )
            self.model.max_seq_length = max_seq_length
            if "NV-Embed-v2" in model_dir:
                self.model.tokenizer.padding_side = "right"
            self.pool = self.model.start_multi_process_pool()

    def encode(
            self,
            sentences,
            batch_size,
            chunk_size,
            chunk_overlap,
            max_seq_length,
            is_q,
    ):
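        """Encode sentences, adding the model-specific query/document prompt.

        For the dewey model the result may be a list with one vector matrix
        per sentence (full-text and chunk vectors); SentenceTransformer models
        return a single (n_sentences, dim) array.
        """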
        if self.model_type == "dewey":
            if is_q:
                prompt = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
            else:
                prompt = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"
            return self.model.encode(
                sentences=sentences,
                batch_size=batch_size,
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                convert_to_tensor=False,
                max_seq_length=max_seq_length,
                normalize_embeddings=True,
                prompt=prompt,
                fast_chunk=True,
            )[0]
        self.model.max_seq_length = max_seq_length
        # pick the query prompt each model family expects (model_dir is the
        # module-level global set in __main__)
        prompt = None
        if is_q and ("Linq-Embed-Mistral" in model_dir or "e5-mistral-7b-instruct" in model_dir
                     or "SFR-Embedding-Mistral" in model_dir):
            prompt = PROMPT_E5
        if is_q and ("NV-Embed-v2" in model_dir):
            prompt = PROMPT_NV
        if "chunk_alignment" in model_dir or "dewey" in model_dir:
            if is_q:
                prompt = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
            else:
                prompt = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"
        vecs = self.model.encode_multi_process(
            add_eos(sentences) if "NV-Embed-v2" in model_dir else sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=batch_size,
            normalize_embeddings=True,
            prompt=prompt
        )
        return vecs


def add_eos(input_examples):
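    """Append the tokenizer's EOS token to every input, as NV-Embed-v2 expects."""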
    input_examples = [input_example + model.tokenizer.eos_token for input_example in input_examples]
    return input_examples


PROMPT_BGE = "Represent this sentence for searching relevant passages:"
PROMPT_E5 = "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: "
PROMPT_NV = "Instruct: Given a question, retrieve passages that answer the question\nQuery: "
if __name__ == "__main__":
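    # Evaluation hyper-parameters. chunk_size and chunk_overlap only affect
    # the dewey encoder; max_seq_length caps input length for all models.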
    chunk_size = -1
    chunk_overlap = 32
    batch_size = 2
    max_seq_length = 8 * 1024
    multi_vec_strategy = "full_text"  # full_text; full_text+chunks
    err_data_save_path = None
    topk = 10

    model_dir = "infgrad/dewey_en_beta"
    # model_dir = "/home/zd/public_models/Linq-Embed-Mistral/"
    # model_dir = "/home/zd/public_models/SFR-Embedding-Mistral"
    # model_dir = "/home/zd/public_models/e5-mistral-7b-instruct"
    # model_dir = "/home/zd/public_models/bge-m3"
    # model_dir = "/home/zd/public_models/gte-modernbert-base"
    # model_dir = "/home/zd/public_models/NV-Embed-v2"

    # model_type: one of sentence_transformer, dewey
    model_type = "sentence_transformer"
    # get data info
    # TODO: download the LoCoV1 data first and point these paths at it
    data_info = get_loco_path_info(
        "/home/zd/public_data/LoCoV1-Queries/documents/",
        "/home/zd/public_data/LoCoV1-Documents/documents/",
    )

    # load model
    model = ModelWrapper(model_dir=model_dir, model_type=model_type, max_seq_length=max_seq_length)
    ndcg_score_list = []
    for item in data_info:
        print("\n\n\n\n" + "=" * 20)
        print(f"evaluate {item[:2]}...")
        query_list, passage_list, query2passage_id_list = get_loco_data(*item[2:])
        print("number of all queries", len(query_list))
        print("number of all passages", len(passage_list))
        ndcg = get_ndcg_score(query_list, passage_list, query2passage_id_list, topk=topk,
                              error_data_save_path=err_data_save_path)
        print(f"ndcg@{topk}: {ndcg}")
        ndcg_score_list.append(ndcg)

    # print dataset labels, subset names, and scores in separate blocks so
    # they can be pasted into a spreadsheet column by column
    for i in data_info:
        print(i[0])
    print("\n\n\n")
    for i in data_info:
        print(i[1].replace(".jsonl", ""))
    print("\n\n\n")

    print(os.path.basename(model_dir))
    for i in ndcg_score_list:
        print(i)