File size: 6,396 Bytes
fbc1304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os

# Point Hugging Face hub traffic at a mirror endpoint. This is set before the
# mteb / sentence_transformers imports below, presumably so huggingface_hub
# picks it up when it loads its configuration — keep it above those imports.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Cap the OpenBLAS thread count. Assumption (confirm): numpy here is backed by
# OpenBLAS, which reads this variable when the library is loaded, so it must be
# set before numpy is imported.
os.environ["OPENBLAS_NUM_THREADS"] = "32"
import mteb
import torch
import numpy as np
from mteb.encoder_interface import PromptType
from sentence_transformers import SentenceTransformer

# MTEB task name -> MTEB task type for every task in the MTEB(eng, v2)
# benchmark. DeweyWrapper.encode uses the type to decide which instruction
# prompt to prepend.
TASK_NAME2TYPE = {
    "ArguAna": "Retrieval",
    "ArXivHierarchicalClusteringP2P": "Clustering",
    "ArXivHierarchicalClusteringS2S": "Clustering",
    "AskUbuntuDupQuestions": "Reranking",
    "BIOSSES": "STS",
    "Banking77Classification": "Classification",
    "BiorxivClusteringP2P.v2": "Clustering",
    "CQADupstackGamingRetrieval": "Retrieval",
    "CQADupstackUnixRetrieval": "Retrieval",
    "ClimateFEVERHardNegatives": "Retrieval",
    "FEVERHardNegatives": "Retrieval",
    "FiQA2018": "Retrieval",
    "HotpotQAHardNegatives": "Retrieval",
    "ImdbClassification": "Classification",
    "MTOPDomainClassification": "Classification",
    "MassiveIntentClassification": "Classification",
    "MassiveScenarioClassification": "Classification",
    "MedrxivClusteringP2P.v2": "Clustering",
    "MedrxivClusteringS2S.v2": "Clustering",
    "MindSmallReranking": "Reranking",
    "SCIDOCS": "Retrieval",
    "SICK-R": "STS",
    "STS12": "STS",
    "STS13": "STS",
    "STS14": "STS",
    "STS15": "STS",
    "STSBenchmark": "STS",
    "SprintDuplicateQuestions": "PairClassification",
    "StackExchangeClustering.v2": "Clustering",
    "StackExchangeClusteringP2P.v2": "Clustering",
    "TRECCOVID": "Retrieval",
    "Touche2020Retrieval.v3": "Retrieval",
    "ToxicConversationsClassification": "Classification",
    "TweetSentimentExtractionClassification": "Classification",
    "TwentyNewsgroupsClustering.v2": "Clustering",
    "TwitterSemEval2015": "PairClassification",
    "TwitterURLCorpus": "PairClassification",
    "SummEvalSummarization.v2": "Summarization",
    "AmazonCounterfactualClassification": "Classification",
    "STS17": "STS",
    "STS22.v2": "STS",
}

def _instruction(text: str) -> str:
    """Return *text* wrapped in the model's instruction delimiter tokens."""
    return f"<|START_INSTRUCTION|>{text}<|END_INSTRUCTION|>"


# Asymmetric retrieval prompts: one for queries, one for candidate documents.
RETRIEVE_Q_PROMPT = _instruction("Answer the question")
RETRIEVE_P_PROMPT = _instruction("Candidate document")
# Shared prompt for symmetric tasks (STS, pair classification, summarization, reranking).
STS_PROMPT = _instruction("Generate semantically similar text")

# Task-specific prompts for the classification and clustering tasks.
TASK_NAME2PROMPT = {
    # Classification
    "Banking77Classification": _instruction("Classify text into intents"),
    "ImdbClassification": _instruction("Classify text into sentiment"),
    "MTOPDomainClassification": _instruction("Classify text into intent domain"),
    "MassiveIntentClassification": _instruction("Classify text into user intents"),
    "MassiveScenarioClassification": _instruction("Classify text into user scenarios"),
    "ToxicConversationsClassification": _instruction("Classify text into toxic or not toxic"),
    "TweetSentimentExtractionClassification": _instruction(
        "Classify text into positive, negative, or neutral sentiment"
    ),
    "AmazonCounterfactualClassification": _instruction(
        "Classify text into counterfactual or not-counterfactual"
    ),
    # Clustering
    "ArXivHierarchicalClusteringP2P": _instruction(
        "Output main and secondary category of Arxiv papers based on the titles and abstracts"
    ),
    "ArXivHierarchicalClusteringS2S": _instruction(
        "Output main and secondary category of Arxiv papers based on the titles"
    ),
    "BiorxivClusteringP2P.v2": _instruction(
        "Output main category of Biorxiv papers based on the titles and abstracts"
    ),
    "MedrxivClusteringP2P.v2": _instruction(
        "Output main category of Medrxiv papers based on the titles and abstracts"
    ),
    "MedrxivClusteringS2S.v2": _instruction(
        "Output main category of Medrxiv papers based on the titles"
    ),
    "StackExchangeClustering.v2": _instruction(
        "Output topic or theme of StackExchange posts based on the titles"
    ),
    "StackExchangeClusteringP2P.v2": _instruction(
        "Output topic or theme of StackExchange posts based on the given paragraphs"
    ),
    "TwentyNewsgroupsClustering.v2": _instruction("Output topic or theme of news articles"),
}


class DeweyWrapper:
    """MTEB-style encoder around a Dewey SentenceTransformer model.

    The model is loaded on the GPU in bfloat16 and a sentence-transformers
    multi-process pool is started. ``encode`` picks an instruction prompt from
    the task name (and, for retrieval, the prompt type). It then embeds the
    sentences through that pool. Call ``close()`` to stop the pool's workers.
    """

    def __init__(self, model_dir: str, max_seq_length: int = 1536, batch_size: int = 8):
        """Load the model and start the multi-process encoding pool.

        Args:
            model_dir: Local model directory or Hugging Face hub id.
            max_seq_length: Assigned to ``model.max_seq_length``.
            batch_size: Batch size passed to ``encode_multi_process``.
        """
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,  # fp16 tends to produce NaNs
                "attn_implementation": "flash_attention_2",
            },
            # Option consumed by the model's remote code; presumably selects how
            # the single sentence vector is pooled (CLS + mean) — confirm.
            config_kwargs={"single_vector_type": "cls_add_mean"},
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        # Worker pool used by encode(); released by close().
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def close(self) -> None:
        """Stop the multi-process pool started in ``__init__``; safe to call twice."""
        if self.pool is not None:
            SentenceTransformer.stop_multi_process_pool(self.pool)
            self.pool = None

    @staticmethod
    def _select_prompt(task_name: str, prompt_type: PromptType | None) -> str:
        """Return the instruction prompt for *task_name* / *prompt_type*.

        Raises:
            KeyError: if the task (or its task-specific prompt) is not registered.
        """
        try:
            task_type = TASK_NAME2TYPE[task_name]
        except KeyError:
            raise KeyError(f"Unknown task {task_name!r}: add it to TASK_NAME2TYPE") from None

        if task_type == "Retrieval":
            # Queries and documents get different instructions. With no prompt
            # type given, fall back to the document prompt instead of crashing
            # on ``None.value``.
            is_query = prompt_type is not None and prompt_type.value == "query"
            return RETRIEVE_Q_PROMPT if is_query else RETRIEVE_P_PROMPT
        if task_type in ("STS", "PairClassification", "Summarization", "Reranking"):
            # Symmetric tasks: both sides share one generic prompt.
            return STS_PROMPT
        try:
            return TASK_NAME2PROMPT[task_name]
        except KeyError:
            raise KeyError(
                f"No instruction prompt for {task_type} task {task_name!r}: "
                "add it to TASK_NAME2PROMPT"
            ) from None

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        """Embed *sentences* with the prompt appropriate for *task_name*.

        Args:
            sentences: Texts to embed.
            task_name: MTEB task name; must be a key of ``TASK_NAME2TYPE``.
            prompt_type: Query vs. passage; only consulted for retrieval tasks.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A float32 array of L2-normalized embeddings, one row per sentence.
        """
        prompt = self._select_prompt(task_name, prompt_type)
        if len(sentences) == 0:
            # encode_multi_process concatenates per-chunk results, so an empty
            # input is answered here with an empty (0, dim) array instead.
            dim = self.model.get_sentence_embedding_dimension() or 0
            return np.empty((0, dim), dtype=np.float32)
        return self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32",
        )


if __name__ == "__main__":
    # Evaluate the Dewey English model on the MTEB(eng, v2) benchmark.
    max_seq_length = 1536
    batch_size = 8
    model_dir_or_name = "infgrad/dewey_en_beta"
    # Name the results folder after the last path component of the model id,
    # e.g. "infgrad/dewey_en_beta" -> "./mteb_eng_results/dewey_en_beta".
    model_short_name = os.path.basename(os.path.normpath(model_dir_or_name))
    output_folder = f"./mteb_eng_results/{model_short_name}"
    model = DeweyWrapper(model_dir_or_name, max_seq_length=max_seq_length, batch_size=batch_size)

    try:
        tasks = list(mteb.get_benchmark("MTEB(eng, v2)"))
        evaluation = mteb.MTEB(tasks=tasks)
        evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
    finally:
        # DeweyWrapper starts a sentence-transformers worker pool in __init__;
        # stop it so the worker processes shut down even if evaluation fails.
        SentenceTransformer.stop_multi_process_pool(model.pool)