import logging
from typing import Any, Dict, List, Mapping

import torch
from torch import Tensor

try:
    from transformers import PreTrainedTokenizerFast, BatchEncoding
except ImportError:
    # transformers is referenced only in type annotations in this module, so
    # the tensor / instruction helpers stay importable without it.
    PreTrainedTokenizerFast = BatchEncoding = Any


def _setup_logger():
    """Configure the root logger for console output at INFO level.

    Any handlers already attached to the root logger are replaced by a single
    stream handler using a "[time level] message" format.

    Returns:
        The configured root logger.
    """
    log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s")
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_format)
    # Replace (rather than append to) existing handlers to avoid duplicate output.
    logger.handlers = [console_handler]

    return logger


logger = _setup_logger()


def move_to_cuda(sample):
    """Recursively move every tensor found in ``sample`` to the CUDA device.

    Dicts, lists, tuples and other ``Mapping`` instances are traversed and
    rebuilt; non-tensor leaves are returned unchanged.

    Args:
        sample: A tensor or a (possibly nested) container of tensors.

    Returns:
        A structure mirroring ``sample`` with tensors moved via
        ``.cuda(non_blocking=True)``, or ``{}`` when ``len(sample) == 0``.
    """
    # Use len() instead of truthiness: ``not tensor`` raises for tensors with
    # more than one element.
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            return tuple(_move_to_cuda(x) for x in maybe_tensor)
        if isinstance(maybe_tensor, Mapping):
            # Non-dict mappings are rebuilt through their own constructor so
            # the container type is preserved.
            return type(maybe_tensor)({k: _move_to_cuda(v) for k, v in maybe_tensor.items()})
        return maybe_tensor

    return _move_to_cuda(sample)


def pool(last_hidden_states: Tensor,
         attention_mask: Tensor,
         pool_type: str) -> Tensor:
    """Pool per-token hidden states into one embedding per sequence.

    Positions where ``attention_mask`` is zero are zeroed out before pooling.
    Neither input tensor is modified.

    Args:
        last_hidden_states: Token representations; indexed here as
            (batch, sequence, hidden).
        attention_mask: Mask indexed as (batch, sequence); non-zero marks
            positions to keep.
        pool_type: One of ``"avg"``, ``"weightedavg"``, ``"cls"`` or ``"last"``.

    Returns:
        The pooled embeddings, one row per batch element.

    Raises:
        ValueError: If ``pool_type`` is not one of the supported values.
    """
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

    if pool_type == "avg":
        emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pool_type == "weightedavg":
        # position-weighted mean pooling from SGPT (https://arxiv.org/abs/2202.08904)
        # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
        # Computed out of place so the caller's attention mask is not overwritten.
        weights = attention_mask * attention_mask.cumsum(dim=1)
        s = torch.sum(last_hidden * weights.unsqueeze(-1).float(), dim=1)
        d = weights.sum(dim=1, keepdim=True).float()
        emb = s / d
    elif pool_type == "cls":
        emb = last_hidden[:, 0]
    elif pool_type == "last":
        # If every sequence attends to the final position the batch is treated
        # as left-padded and the final position is the last token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            emb = last_hidden[:, -1]
        else:
            # Otherwise index each row at (number of attended positions - 1).
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden.shape[0]
            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
    else:
        raise ValueError(f"pool_type {pool_type} not supported")

    return emb


def create_batch_dict(tokenizer: PreTrainedTokenizerFast,
                      input_texts: List[str],
                      always_add_eos: bool,
                      max_length: int = 512) -> BatchEncoding:
    """Tokenize ``input_texts`` into a padded batch of PyTorch tensors.

    Args:
        tokenizer: Tokenizer used to encode and pad the texts.
        input_texts: Texts to encode.
        always_add_eos: When true, texts are truncated to ``max_length - 1``
            tokens and ``tokenizer.eos_token_id`` is appended to every
            sequence before padding.
        max_length: Maximum sequence length (including the appended EOS).

    Returns:
        The padded batch (padded to a multiple of 8) as returned by the tokenizer.
    """
    if not always_add_eos:
        return tokenizer(
            input_texts,
            max_length=max_length,
            padding=True,
            pad_to_multiple_of=8,
            return_token_type_ids=False,
            truncation=True,
            return_tensors='pt'
        )

    # Leave room for the EOS token that is appended below.
    batch_dict = tokenizer(
        input_texts,
        max_length=max_length - 1,
        return_token_type_ids=False,
        return_attention_mask=False,
        padding=False,
        truncation=True
    )

    # append eos_token_id to every input_ids
    batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]

    return tokenizer.pad(
        batch_dict,
        padding=True,
        pad_to_multiple_of=8,
        return_attention_mask=True,
        return_tensors="pt",
    )


# Task types whose instruction does not depend on the task name.
_TASK_TYPE_TO_INSTRUCT: Dict[str, str] = {
    'STS': "Retrieve semantically similar text.",
    'Summarization': "Given a news summary, retrieve other semantically similar summaries",
    'BitextMining': "Retrieve parallel sentences.",
}

_CLASSIFICATION_INSTRUCT: Dict[str, str] = {
    'AmazonCounterfactualClassification': 'Classify a given Amazon customer review text as either counterfactual or not-counterfactual',
    'AmazonPolarityClassification': 'Classify Amazon reviews into positive or negative sentiment',
    'AmazonReviewsClassification': 'Classify the given Amazon review into its appropriate rating category',
    'Banking77Classification': 'Given a online banking query, find the corresponding intents',
    'EmotionClassification': 'Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise',
    'ImdbClassification': 'Classify the sentiment expressed in the given movie review text from the IMDB dataset',
    'MassiveIntentClassification': 'Given a user utterance as query, find the user intents',
    'MassiveScenarioClassification': 'Given a user utterance as query, find the user scenarios',
    'MTOPDomainClassification': 'Classify the intent domain of the given utterance in task-oriented conversation',
    'MTOPIntentClassification': 'Classify the intent of the given utterance in task-oriented conversation',
    'ToxicConversationsClassification': 'Classify the given comments as either toxic or not toxic',
    'TweetSentimentExtractionClassification': 'Classify the sentiment of a given tweet as either positive, negative, or neutral',
    # C-MTEB eval instructions
    'TNews': 'Classify the fine-grained category of the given news title',
    'IFlyTek': 'Given an App description text, find the appropriate fine-grained category',
    'MultilingualSentiment': 'Classify sentiment of the customer review into positive, neutral, or negative',
    'JDReview': 'Classify the customer review for iPhone on e-commerce platform into positive or negative',
    'OnlineShopping': 'Classify the customer review for online shopping into positive or negative',
    'Waimai': 'Classify the customer review from a food takeaway platform into positive or negative',
}

_CLUSTERING_INSTRUCT: Dict[str, str] = {
    'ArxivClusteringP2P': 'Identify the main and secondary category of Arxiv papers based on the titles and abstracts',
    'ArxivClusteringS2S': 'Identify the main and secondary category of Arxiv papers based on the titles',
    'BiorxivClusteringP2P': 'Identify the main category of Biorxiv papers based on the titles and abstracts',
    'BiorxivClusteringS2S': 'Identify the main category of Biorxiv papers based on the titles',
    'MedrxivClusteringP2P': 'Identify the main category of Medrxiv papers based on the titles and abstracts',
    'MedrxivClusteringS2S': 'Identify the main category of Medrxiv papers based on the titles',
    'RedditClustering': 'Identify the topic or theme of Reddit posts based on the titles',
    'RedditClusteringP2P': 'Identify the topic or theme of Reddit posts based on the titles and posts',
    'StackExchangeClustering': 'Identify the topic or theme of StackExchange posts based on the titles',
    'StackExchangeClusteringP2P': 'Identify the topic or theme of StackExchange posts based on the given paragraphs',
    'TwentyNewsgroupsClustering': 'Identify the topic or theme of the given news articles',
    # C-MTEB eval instructions
    'CLSClusteringS2S': 'Identify the main category of scholar papers based on the titles',
    'CLSClusteringP2P': 'Identify the main category of scholar papers based on the titles and abstracts',
    'ThuNewsClusteringS2S': 'Identify the topic or theme of the given news articles based on the titles',
    'ThuNewsClusteringP2P': 'Identify the topic or theme of the given news articles based on the titles and contents',
}

# Shared by the 'Reranking' and 'PairClassification' task types.
_PAIR_INSTRUCT: Dict[str, str] = {
    'AskUbuntuDupQuestions': 'Retrieve duplicate questions from AskUbuntu forum',
    'MindSmallReranking': 'Retrieve relevant news articles based on user browsing history',
    'SciDocsRR': 'Given a title of a scientific paper, retrieve the titles of other relevant papers',
    'StackOverflowDupQuestions': 'Retrieve duplicate questions from StackOverflow forum',
    'SprintDuplicateQuestions': 'Retrieve duplicate questions from Sprint forum',
    'TwitterSemEval2015': 'Retrieve tweets that are semantically similar to the given tweet',
    'TwitterURLCorpus': 'Retrieve tweets that are semantically similar to the given tweet',
    # C-MTEB eval instructions
    'T2Reranking': 'Given a Chinese search query, retrieve web passages that answer the question',
    'MMarcoReranking': 'Given a Chinese search query, retrieve web passages that answer the question',
    'CMedQAv1': 'Given a Chinese community medical question, retrieve replies that best answer the question',
    'CMedQAv2': 'Given a Chinese community medical question, retrieve replies that best answer the question',
    'Ocnli': 'Retrieve semantically similar text.',
    'Cmnli': 'Retrieve semantically similar text.',
}

_CQADUPSTACK_INSTRUCT = 'Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question'


def _build_retrieval_instruct() -> Dict[str, str]:
    """Build the Retrieval instruction table, including lower-case and alias keys."""
    task_name_to_instruct: Dict[str, str] = {
        'ArguAna': 'Given a claim, find documents that refute the claim',
        'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim',
        'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia',
        'FEVER': 'Given a claim, retrieve documents that support or refute the claim',
        'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question',
        'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question',
        'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query',
        'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question',
        'NQ': 'Given a question, retrieve Wikipedia passages that answer the question',
        'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question',
        'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
        'SciFact': 'Given a scientific claim, retrieve documents that support or refute the claim',
        'Touche2020': 'Given a question, retrieve detailed and persuasive arguments that answer the question',
        'TRECCOVID': 'Given a query on COVID-19, retrieve documents that answer the query',
        # C-MTEB eval instructions
        'T2Retrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
        'MMarcoRetrieval': 'Given a web search query, retrieve relevant passages that answer the query',
        'DuRetrieval': 'Given a Chinese search query, retrieve web passages that answer the question',
        'CovidRetrieval': 'Given a question on COVID-19, retrieve news articles that answer the question',
        'CmedqaRetrieval': 'Given a Chinese community medical question, retrieve replies that best answer the question',
        'EcomRetrieval': 'Given a user query from an e-commerce website, retrieve description sentences of relevant products',
        'MedicalRetrieval': 'Given a medical question, retrieve user replies that best answer the question',
        'VideoRetrieval': 'Given a video search query, retrieve the titles of relevant videos',
    }

    # add lower case keys to match some beir names
    task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()})
    # other cases where lower case match still doesn't work
    task_name_to_instruct['trec-covid'] = task_name_to_instruct['TRECCOVID']
    task_name_to_instruct['climate-fever'] = task_name_to_instruct['ClimateFEVER']
    task_name_to_instruct['dbpedia-entity'] = task_name_to_instruct['DBPedia']
    task_name_to_instruct['webis-touche2020'] = task_name_to_instruct['Touche2020']
    task_name_to_instruct['fiqa'] = task_name_to_instruct['FiQA2018']
    task_name_to_instruct['quora'] = task_name_to_instruct['QuoraRetrieval']

    # for miracl evaluation
    task_name_to_instruct['miracl'] = 'Given a question, retrieve Wikipedia passages that answer the question'

    return task_name_to_instruct


_RETRIEVAL_INSTRUCT: Dict[str, str] = _build_retrieval_instruct()

# Task types whose instruction is looked up by task name.
_TASK_TYPE_TO_NAME_TABLE: Dict[str, Dict[str, str]] = {
    'Classification': _CLASSIFICATION_INSTRUCT,
    'Clustering': _CLUSTERING_INSTRUCT,
    'Reranking': _PAIR_INSTRUCT,
    'PairClassification': _PAIR_INSTRUCT,
    'Retrieval': _RETRIEVAL_INSTRUCT,
}


def get_task_def_by_task_name_and_type(task_name: str, task_type: str) -> str:
    """Return the instruction text for an evaluation task.

    Args:
        task_name: Name of the task (e.g. ``'NQ'`` or ``'trec-covid'``).
        task_type: Task category; one of ``'STS'``, ``'Summarization'``,
            ``'BitextMining'``, ``'Classification'``, ``'Clustering'``,
            ``'Reranking'``, ``'PairClassification'`` or ``'Retrieval'``.

    Returns:
        The instruction string for the task.

    Raises:
        KeyError: If ``task_type`` is supported but has no entry for ``task_name``.
        ValueError: If ``task_type`` is not supported.
    """
    if task_type in _TASK_TYPE_TO_INSTRUCT:
        return _TASK_TYPE_TO_INSTRUCT[task_type]

    # All CQADupstack sub-tasks share a single instruction.
    if task_type == 'Retrieval' and task_name.lower().startswith('cqadupstack'):
        return _CQADUPSTACK_INSTRUCT

    if task_type in _TASK_TYPE_TO_NAME_TABLE:
        return _TASK_TYPE_TO_NAME_TABLE[task_type][task_name]

    raise ValueError(f"No instruction config for task {task_name} with type {task_type}")


def get_detailed_instruct(task_description: str) -> str:
    """Format a task description as an instruction prefix for a query.

    Args:
        task_description: Instruction text; may be empty.

    Returns:
        ``'Instruct: <task_description>\\nQuery: '``, or ``''`` when
        ``task_description`` is empty.
    """
    if not task_description:
        return ''

    return 'Instruct: {}\nQuery: '.format(task_description)