import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # optional HF mirror endpoint
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import numpy as np
import torch
import mteb
from mteb.encoder_interface import PromptType
from mteb.models.wrapper import Wrapper
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
from typing import Any, Sequence

RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"
# Module-level so both wrappers can read it at construction/encode time: 128k tokens.
max_seq_length = 128 * 1024


class DeweySingleVectorWrapper:
    def __init__(self, model_dir, batch_size: int = 8):
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 tends to produce NaN, so use bfloat16
                "attn_implementation": "flash_attention_2"
            },
            config_kwargs={"single_vector_type": "mean"}
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
            self,
            sentences: list[str],
            task_name: str,
            prompt_type: PromptType | None = None,
            **kwargs,
    ) -> np.ndarray:
        if prompt_type == PromptType.query:
            prompt = RETRIEVE_Q_PROMPT
        else:
            prompt = RETRIEVE_P_PROMPT
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32"
        )
        return vectors


class DeweyMultiVectorWrapper(Wrapper):
    def __init__(
            self,
            model_dir: str,
            batch_size: int = 8,
            *args,
            **kwargs,
    ) -> None:
        self.model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        ).cuda().bfloat16()
        self.batch_size = batch_size
        self.model.tokenizer = AutoTokenizer.from_pretrained(model_dir)

    def encode(
            self,
            sentences: Sequence[str],
            *,
            task_name: str,
            prompt_type: PromptType | None = None,
            **kwargs: Any,
    ) -> np.ndarray:
        if prompt_type == PromptType.query:
            # Queries are encoded whole (chunk_size=-1). A query does not need
            # multiple vectors, so keep only the mean vector (row 1) per query.
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=-1,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=RETRIEVE_Q_PROMPT,
                fast_chunk=False
            )[0]
            pred = [vecs[1:2, :] for vecs in pred]
        else:
            # Documents are split into 256-token chunks, yielding one vector per chunk.
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=256,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=RETRIEVE_P_PROMPT,
                fast_chunk=True,
            )[0]
        # Pad to a common vector count so the batch stacks into one array.
        pred = torch.nn.utils.rnn.pad_sequence(pred, batch_first=True, padding_value=0)
        return pred.cpu().numpy()

    def similarity(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, dtype=torch.float32)
        if not isinstance(b, torch.Tensor):
            b = torch.tensor(b, dtype=torch.float32)
        if a.dim() == 2:
            a = a.unsqueeze(0)
        if b.dim() == 2:
            b = b.unsqueeze(0)
        # Token-level dot products with shape (a_batch, b_batch, a_vectors, b_vectors).
        scores = torch.einsum("ash,bth->abst", a, b)
        # MaxSim: best-matching document vector per query vector, summed over query vectors.
        return scores.max(dim=-1).values.sum(dim=-1)


if __name__ == "__main__":
    ################# evaluate single vector #################
    # batch_size = 4
    # model = DeweySingleVectorWrapper("infgrad/dewey_en_beta", batch_size=batch_size)
    # output_folder = "./long_embed_benchmark/dewey_en_beta_single_vector_128k"
    # tasks = list(mteb.get_benchmark("LongEmbed"))
    # evaluation = mteb.MTEB(tasks=tasks)
    # evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)

    ################# evaluate multi vectors #################
    batch_size = 4
    model = DeweyMultiVectorWrapper("infgrad/dewey_en_beta", batch_size=batch_size)
    output_folder = "./long_embed_benchmark/dewey_en_beta_multi_vectors"
    tasks = list(mteb.get_benchmark("LongEmbed"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
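

# --- Illustrative sketch, not part of the benchmark run ---
# The `similarity` method above implements ColBERT-style late interaction
# (MaxSim): each query vector is matched to its best document vector, and
# the per-vector maxima are summed. The toy shapes below are assumptions
# chosen only to show the mechanics.
def _maxsim_demo() -> None:
    q = torch.randn(2, 1, 8)  # 2 queries x 1 mean vector each x dim 8
    d = torch.randn(5, 7, 8)  # 5 documents x 7 chunk vectors each x dim 8
    scores = torch.einsum("ash,bth->abst", q, d)  # (2, 5, 1, 7) pairwise dots
    scores = scores.max(dim=-1).values.sum(dim=-1)  # MaxSim over doc vectors, sum over query vectors
    assert scores.shape == (2, 5)  # one score per (query, document) pair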