import os

# These environment variables must be set before the Hugging Face libraries are imported.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["OPENBLAS_NUM_THREADS"] = "32"

from typing import Any, Sequence

import numpy as np
import torch

import mteb
from mteb.encoder_interface import PromptType
from mteb.models.wrapper import Wrapper
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"


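# Single-vector evaluation: each text is encoded into one mean-pooled embedding
# (config_kwargs={"single_vector_type": "mean"}) and scored with mteb's default
# similarity on the normalized vectors.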
class DeweySingleVectorWrapper:
    def __init__(self, model_dir: str, max_seq_length: int = 128 * 1024, batch_size: int = 8):
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 容易计算出nan
                "attn_implementation": "flash_attention_2"
            },
            config_kwargs={"single_vector_type": "mean"}
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
            self,
            sentences: list[str],
            task_name: str,
            prompt_type: PromptType | None = None,
            **kwargs,
    ) -> np.ndarray:
        if prompt_type.value == "query":
            prompt = RETRIEVE_Q_PROMPT
        else:
            prompt = RETRIEVE_P_PROMPT
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32"
        )
        return vectors


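# Multi-vector evaluation: documents are split into chunks and encoded into one
# vector per chunk; similarity() below scores query/document pairs with a
# ColBERT-style MaxSim late interaction.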
class DeweyMultiVectorWrapper(Wrapper):
    def __init__(
            self,
            model_dir: str,
            max_seq_length: int = 128 * 1024,
            batch_size: int = 8,
            *args,
            **kwargs,
    ) -> None:
        self.model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        ).cuda().bfloat16().eval()
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.model.tokenizer = AutoTokenizer.from_pretrained(model_dir)

    def encode(
            self,
            sentences: Sequence[str],
            *,
            task_name: str,
            prompt_type: PromptType | None = None,
            **kwargs: Any,
    ) -> np.ndarray:
        if prompt_type.value == "query":
            prompt = RETRIEVE_Q_PROMPT
        else:
            prompt = RETRIEVE_P_PROMPT
        if prompt_type.value == "query":
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=-1,  # -1: encode the whole query as a single chunk
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=self.max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=prompt,
                fast_chunk=False,
            )[0]
            # Queries do not need multiple vectors; keep only the single mean vector.
            pred = [vecs[1:2, :] for vecs in pred]
        else:
            prompt = RETRIEVE_P_PROMPT
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=256,  # split long documents into ~256-token chunks
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=self.max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=prompt,
                fast_chunk=True,
            )[0]

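        # Pad the per-text vector lists to a rectangular (batch, max_vectors, dim)
        # array so mteb can store them as one ndarray; the zero padding adds
        # nothing to the MaxSim sum in similarity() below.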
        pred = torch.nn.utils.rnn.pad_sequence(pred, batch_first=True, padding_value=0)
        return pred.cpu().numpy()

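    # MaxSim late interaction: for each query vector take its best-matching
    # document vector (max over t), then sum these maxima over the query
    # vectors (sum over s), giving one score per (query, document) pair.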
    def similarity(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, dtype=torch.float32)

        if not isinstance(b, torch.Tensor):
            b = torch.tensor(b, dtype=torch.float32)

        # Treat 2-D inputs as a single multi-vector item.
        if a.ndim == 2:
            a = a.unsqueeze(0)

        if b.ndim == 2:
            b = b.unsqueeze(0)

        # (a, s, h) x (b, t, h) -> (a, b, s, t): all pairwise dot products
        # between the s query vectors and the t document vectors.
        scores = torch.einsum("ash,bth->abst", a, b)

        return scores.max(dim=-1).values.sum(dim=-1)


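# A minimal shape sanity check for the MaxSim scoring above (illustrative only,
# random data): 2 queries with 4 vectors each against 3 documents with 5 vectors
# each should produce a (2, 3) score matrix.
#
# q = np.random.randn(2, 4, 128).astype(np.float32)
# d = np.random.randn(3, 5, 128).astype(np.float32)
# print(DeweyMultiVectorWrapper.similarity(None, q, d).shape)  # -> (2, 3)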

if __name__ == "__main__":
    #################  evaluate single vector  #################
    # batch_size = 4
    # max_seq_length = 128 * 1024
    # model = DeweySingleVectorWrapper("infgrad/dewey_en_beta", max_seq_length=max_seq_length, batch_size=batch_size)
    # output_folder = "./long_embed_benchmark/dewey_en_beta_single_vector_128k"
    # tasks = list(mteb.get_benchmark("LongEmbed"))
    # evaluation = mteb.MTEB(tasks=tasks)
    # evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)

    #################  evaluate multi vectors  #################
    batch_size = 4
    max_seq_length = 128 * 1024
    model = DeweyMultiVectorWrapper("infgrad/dewey_en_beta", max_seq_length=max_seq_length, batch_size=batch_size)
    output_folder = "./long_embed_benchmark/dewey_en_beta_multi_vectors"

    tasks = list(mteb.get_benchmark("LongEmbed"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)