import os

# Environment must be set before any Hugging Face library is imported: route
# downloads through a mirror and cap OpenBLAS threading.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["OPENBLAS_NUM_THREADS"] = "32"

from typing import Any, Sequence

import numpy as np
import torch

import mteb
from mteb.encoder_interface import PromptType
from mteb.models.wrapper import Wrapper
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer


class DeweySingleVectorWrapper:
    """Single-vector encoder: one mean-pooled embedding per text."""

    def __init__(self, model_dir: str, batch_size: int = 8, max_seq_length: int = 128 * 1024):
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "flash_attention_2",
            },
            config_kwargs={"single_vector_type": "mean"},
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        # One worker process per visible GPU; release it with
        # SentenceTransformer.stop_multi_process_pool(self.pool) when finished.
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        # Queries and documents carry different instruction prompts; fall back
        # to the passage prompt when prompt_type is None.
        if prompt_type == PromptType.query:
            prompt = RETRIEVE_Q_PROMPT
        else:
            prompt = RETRIEVE_P_PROMPT
        vectors = self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32",
        )
        return vectors
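
# Usage sketch (not executed here): evaluating the single-vector wrapper on the
# same benchmark as the __main__ block below; the output folder name is
# illustrative, not prescribed by the source.
#
#   model = DeweySingleVectorWrapper("infgrad/dewey_en_beta", batch_size=4)
#   evaluation = mteb.MTEB(tasks=list(mteb.get_benchmark("LongEmbed")))
#   evaluation.run(model, output_folder="./long_embed_benchmark/dewey_en_beta_single_vector")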


class DeweyMultiVectorWrapper(Wrapper):
    """Multi-vector encoder: one vector per chunk, scored by late interaction."""

    def __init__(
        self,
        model_dir: str,
        batch_size: int = 8,
        max_seq_length: int = 128 * 1024,
        *args,
        **kwargs,
    ) -> None:
        self.model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        ).cuda().bfloat16()
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        self.model.tokenizer = AutoTokenizer.from_pretrained(model_dir)

    def encode(
        self,
        sentences: Sequence[str],
        *,
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs: Any,
    ) -> np.ndarray:
        if prompt_type == PromptType.query:
            # Queries are short: encode each one whole (chunk_size=-1) and keep
            # a single vector per query (row 1 of each returned matrix).
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=-1,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=self.max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=RETRIEVE_Q_PROMPT,
                fast_chunk=False,
            )[0]
            pred = [vecs[1:2, :] for vecs in pred]
        else:
            # Documents are long: split into 256-token chunks with a 32-token
            # overlap, yielding one vector per chunk.
            pred = self.model.encode(
                sentences=list(sentences),
                use_cuda=True,
                show_progress_bar=True,
                chunk_size=256,
                chunk_overlap=32,
                convert_to_tensor=True,
                max_seq_length=self.max_seq_length,
                batch_size=self.batch_size,
                normalize_embeddings=True,
                prompt=RETRIEVE_P_PROMPT,
                fast_chunk=True,
            )[0]
        # Pad the variable-length chunk sequences to a common length so the
        # batch stacks into a single rectangular array.
        pred = torch.nn.utils.rnn.pad_sequence(pred, batch_first=True, padding_value=0)
        return pred.cpu().numpy()
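        # Queries therefore come back as (n, 1, h) arrays and documents as
        # (n, max_chunks, h); both are consumed by similarity() below.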

    def similarity(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, dtype=torch.float32)
        if not isinstance(b, torch.Tensor):
            b = torch.tensor(b, dtype=torch.float32)
        # Promote single matrices to batches of one.
        if len(a.shape) == 2:
            a = a.unsqueeze(0)
        if len(b.shape) == 2:
            b = b.unsqueeze(0)
        # All pairwise dot products between the s query vectors and the t
        # document vectors: (a, s, h) x (b, t, h) -> (a, b, s, t).
        scores = torch.einsum("ash,bth->abst", a, b)
        # Late interaction (MaxSim): take the best-matching document vector for
        # each query vector, then sum over query vectors.
        return scores.max(axis=-1).values.sum(axis=-1)
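

# Shape check for the MaxSim scoring above (illustrative values only, not
# model outputs):
#
#   a = torch.randn(2, 3, 8)   # 2 queries, 3 vectors each, dim 8
#   b = torch.randn(5, 7, 8)   # 5 documents, 7 chunk vectors each, dim 8
#   s = torch.einsum("ash,bth->abst", a, b)   # (2, 5, 3, 7)
#   s.max(-1).values.sum(-1)                  # (2, 5): one score per pair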


# Instruction prefixes prepended to queries and documents, respectively.
RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"
RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"

if __name__ == "__main__":
    batch_size = 4
    max_seq_length = 128 * 1024
    model = DeweyMultiVectorWrapper(
        "infgrad/dewey_en_beta",
        batch_size=batch_size,
        max_seq_length=max_seq_length,
    )
    output_folder = "./long_embed_benchmark/dewey_en_beta_multi_vectors"

    tasks = list(mteb.get_benchmark("LongEmbed"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
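    # The run writes one JSON result file per task under output_folder; the
    # returned list can also be inspected in-process (the attribute names below
    # are an assumption about mteb's TaskResult API, not confirmed by the source):
    #
    #   results = evaluation.run(model, output_folder=output_folder)
    #   for res in results:
    #       print(res.task_name, res.scores)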