import os
import json
import tqdm
import numpy as np
import torch
import argparse
import torch.nn.functional as F

from typing import List, Dict
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput
from mteb import MTEB, AbsTaskRetrieval, DRESModel

from utils import pool, logger, move_to_cuda, get_detailed_instruct, get_task_def_by_task_name_and_type, create_batch_dict
from model_config import MODEL_NAME_TO_POOL_TYPE, MODEL_NAME_TO_PREFIX_TYPE
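# utils and model_config are presumably repo-local modules (they provide the
# pooling, logging, and batching helpers plus the per-model pool/prefix lookup
# tables); run this script from the directory that contains them.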
parser = argparse.ArgumentParser(description='evaluation for BEIR benchmark')
parser.add_argument('--model-name-or-path', default='intfloat/e5-small-v2',
                    type=str, metavar='N', help='which model to use')
parser.add_argument('--output-dir', default='tmp-outputs/',
                    type=str, metavar='N', help='output directory')
parser.add_argument('--doc-as-query', action='store_true',
                    help='use the query prefix for passages; only used for Quora, since it is a symmetric task')
parser.add_argument('--pool-type', default='avg', help='pool type')
parser.add_argument('--prefix-type', default='query_or_passage', help='prefix type')
parser.add_argument('--dry-run', action='store_true', help='whether to run the script in dry run mode')

args = parser.parse_args()
base_name: str = args.model_name_or_path.split('/')[-1]
args.pool_type = MODEL_NAME_TO_POOL_TYPE.get(base_name, args.pool_type)
args.prefix_type = MODEL_NAME_TO_PREFIX_TYPE.get(base_name, args.prefix_type)
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
assert args.pool_type in ['cls', 'avg', 'last', 'weightedavg'], 'pool_type should be cls / avg / last / weightedavg'
assert args.prefix_type in ['query_or_passage', 'instruction'], 'prefix_type should be query_or_passage / instruction'

os.makedirs(args.output_dir, exist_ok=True)
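# RetrievalModel implements the DRES (Dense Retrieval Exact Search) interface
# that MTEB's retrieval tasks expect: encode_queries / encode_corpus returning
# L2-normalized numpy embeddings.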
class RetrievalModel(DRESModel):
    # Refer to the code of DRESModel for the methods to overwrite
    def __init__(self, **kwargs):
        self.encoder = AutoModel.from_pretrained(args.model_name_or_path, torch_dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
        self.prompt = None
        self.gpu_count = torch.cuda.device_count()
        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)

        self.encoder.cuda()
        self.encoder.eval()
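
    # Prefix conventions: 'query_or_passage' models (e.g. the e5 family) expect
    # literal 'query: ' / 'passage: ' text prefixes, while 'instruction' models
    # get a task-specific instruction prompt via set_prompt() instead.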
    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        if args.prefix_type == 'query_or_passage':
            input_texts = [f'query: {q}' for q in queries]
        else:
            input_texts = [self.prompt + q for q in queries]

        return self._do_encode(input_texts)
    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        if args.doc_as_query:
            return self.encode_queries([d['text'] for d in corpus], **kwargs)

        input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
        # no need to add prefix for instruct models
        if args.prefix_type == 'query_or_passage':
            input_texts = ['passage: {}'.format(t) for t in input_texts]

        return self._do_encode(input_texts)
    @torch.no_grad()  # inference only: avoid building autograd graphs while encoding
    def _do_encode(self, input_texts: List[str]) -> np.ndarray:
        encoded_embeds = []
        batch_size = 64 * self.gpu_count
        for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc='encoding', mininterval=10):
            batch_input_texts: List[str] = input_texts[start_idx: start_idx + batch_size]

            batch_dict = create_batch_dict(self.tokenizer, batch_input_texts, always_add_eos=(args.pool_type == 'last'))
            batch_dict = move_to_cuda(batch_dict)

            with torch.cuda.amp.autocast():
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type)
                embeds = F.normalize(embeds, p=2, dim=-1)
                encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)
    def set_prompt(self, prompt: str):
        self.prompt = prompt
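
# For reference, a minimal sketch of what utils.pool plausibly computes for
# pool_type='avg' (an assumption -- the authoritative implementation lives in
# utils.py): zero out padded positions, then average the remaining token states.
def _avg_pool_sketch(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # broadcast the mask to (batch, seq_len, 1) so padded positions contribute zeros
    hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    # divide by the true token count per sequence, not by the padded length
    return hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]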
def main():
    assert AbsTaskRetrieval.is_dres_compatible(RetrievalModel)

    model = RetrievalModel()
    task_names = [t.description["name"] for t in MTEB(task_types=['Retrieval'], task_langs=['en']).tasks]
    task_names = [t for t in task_names if t != 'MSMARCOv2']
    logger.info('Tasks: {}'.format(task_names))

    for task in task_names:
        if args.dry_run and task not in ['SciFact', 'FiQA2018']:
            continue

        logger.info('Processing task: {}'.format(task))

        if args.prefix_type == 'query_or_passage':
            args.doc_as_query = task in ['QuoraRetrieval']
        else:
            task_def: str = get_task_def_by_task_name_and_type(task_name=task, task_type='Retrieval')
            prompt: str = get_detailed_instruct(task_def)
            model.set_prompt(prompt=prompt)
            logger.info('Set prompt: {}'.format(prompt))

        evaluation = MTEB(tasks=[task], task_langs=['en'])
        evaluation.run(model, eval_splits=["test" if task not in ['MSMARCO'] else 'dev'],
                       output_folder=args.output_dir)
if __name__ == '__main__':
    main()
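
# Example invocation (the filename below is a placeholder -- use whatever name
# this script is saved under in the repo):
#
#   python mteb_beir_eval.py --model-name-or-path intfloat/e5-small-v2 \
#       --output-dir tmp-outputs/ --dry-run
#
# --dry-run restricts evaluation to SciFact and FiQA2018 as a quick sanity
# check; drop it to run the full English MTEB retrieval suite.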