"""Topk Retriever.""" | |
import copy | |
from dataclasses import dataclass | |
from typing import Any, Dict, List, Optional, Union | |
import numpy as np | |
import torch | |
import tqdm | |
from sentence_transformers import SentenceTransformer | |
from torch.utils.data import DataLoader | |
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase | |
from transformers.file_utils import PaddingStrategy | |
from opencompass.openicl.icl_dataset_reader import DatasetEncoder | |
from opencompass.openicl.icl_retriever import BaseRetriever | |
from opencompass.openicl.utils.logging import get_logger | |
from opencompass.registry import ICL_RETRIEVERS | |
logger = get_logger(__name__) | |


@ICL_RETRIEVERS.register_module()
class TopkRetriever(BaseRetriever):
    """Base class for Top-k In-context Learning Retriever, implemented as a
    basic k-NN search: SentenceTransformer computes the embeddings and Faiss
    performs the nearest-neighbor lookup.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instance.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each in-context
            example template when origin `PromptTemplate` is provided.
            Defaults to '\n'.
        ice_eos_token (`Optional[str]`): The end-of-sentence token for the
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\n'.
        ice_num (`Optional[int]`): The number of in-context examples selected
            when origin `PromptTemplate` is provided. Defaults to 1.
        sentence_transformers_model_name (`Optional[str]`): The name of the
            sentence transformers model. Defaults to 'all-mpnet-base-v2'.
        tokenizer_name (`Optional[str]`): The name of the tokenizer. Defaults
            to 'gpt2-xl'.
        batch_size (`Optional[int]`): The batch size for the dataloader.
            Defaults to 1.
    """

    model = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 sentence_transformers_model_name: Optional[
                     str] = 'all-mpnet-base-v2',
                 tokenizer_name: Optional[str] = 'gpt2-xl',
                 batch_size: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.batch_size = batch_size
        self.tokenizer_name = tokenizer_name
        # Flatten the test split into a plain list of input-field strings.
        gen_datalist = self.dataset_reader.generate_input_field_corpus(
            self.test_ds)

        # GPT-2 style tokenizers have no pad token, so reuse EOS for padding.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = 'right'

        # Tokenize the test corpus and batch it with padding (and transfer
        # to CUDA when available).
        self.encode_dataset = DatasetEncoder(gen_datalist,
                                             tokenizer=self.tokenizer)
        co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,
                                            device=self.device)
        self.dataloader = DataLoader(self.encode_dataset,
                                     batch_size=self.batch_size,
                                     collate_fn=co)

        # One embedding model serves both sides of the k-NN search: the
        # index (train) set and the query (test) set.
        self.model = SentenceTransformer(sentence_transformers_model_name)
        self.model = self.model.to(self.device)
        self.model.eval()

        self.index = self.create_index()

    def create_index(self):
        # Lazy import: faiss is only required once an index is actually built.
        import faiss

        self.select_datalist = self.dataset_reader.generate_input_field_corpus(
            self.index_ds)
        encode_datalist = DatasetEncoder(self.select_datalist,
                                         tokenizer=self.tokenizer)
        co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,
                                            device=self.device)
        dataloader = DataLoader(encode_datalist,
                                batch_size=self.batch_size,
                                collate_fn=co)
        # Flat inner-product index wrapped in an IDMap so that search results
        # come back as dataset row ids rather than insertion positions.
        index = faiss.IndexIDMap(
            faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension()))
        res_list = self.forward(dataloader,
                                process_bar=True,
                                information='Creating index for index set...')
        id_list = np.array([res['metadata']['id'] for res in res_list])
        self.embed_list = np.stack([res['embed'] for res in res_list])
        index.add_with_ids(self.embed_list, id_list)
        return index
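
    # Standalone sketch of the faiss pattern used above (assumes faiss is
    # installed; faiss expects float32 embeddings and int64 ids):
    #
    #   import faiss
    #   import numpy as np
    #   dim = 768
    #   index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))
    #   xb = np.random.rand(10, dim).astype('float32')
    #   index.add_with_ids(xb, np.arange(10, dtype='int64'))
    #   scores, ids = index.search(xb[:1], 3)  # inner-product top-3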

    def knn_search(self, ice_num):
        # Embed every test example, then query the faiss index for its
        # `ice_num` nearest neighbors from the index (train) set.
        res_list = self.forward(self.dataloader,
                                process_bar=True,
                                information='Embedding test set...')
        rtr_idx_list = [[] for _ in range(len(res_list))]
        logger.info('Retrieving data for test set...')
        for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
            idx = entry['metadata']['id']
            embed = np.expand_dims(entry['embed'], axis=0)
            # faiss `search` returns a (distances, ids) pair of (n_queries, k)
            # arrays; `[1][0]` picks the id row of the single query.
            near_ids = self.index.search(embed, ice_num)[1][0].tolist()
            rtr_idx_list[idx] = near_ids
        return rtr_idx_list

    def forward(self, dataloader, process_bar=False, information=''):
        res_list = []
        _dataloader = copy.deepcopy(dataloader)
        if process_bar:
            logger.info(information)
            _dataloader = tqdm.tqdm(_dataloader,
                                    disable=not self.is_main_process)
        for _, entry in enumerate(_dataloader):
            with torch.no_grad():
                metadata = entry.pop('metadata')
                # Decode back to raw strings: SentenceTransformer does its
                # own tokenization, so only the text is passed along.
                raw_text = self.tokenizer.batch_decode(
                    entry['input_ids'],
                    skip_special_tokens=True,
                    verbose=False)
                res = self.model.encode(raw_text, show_progress_bar=False)
            res_list.extend([{
                'embed': r,
                'metadata': m
            } for r, m in zip(res, metadata)])
        return res_list
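
    # Standalone sketch of the embedding call used above (assumes
    # sentence-transformers is installed and the model can be downloaded):
    #
    #   from sentence_transformers import SentenceTransformer
    #   model = SentenceTransformer('all-mpnet-base-v2')
    #   embs = model.encode(['hello world'], show_progress_bar=False)
    #   embs.shape  # (1, 768) float32, suitable for the faiss index above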

    def retrieve(self):
        """Retrieve the in-context example index for each test example."""
        return self.knn_search(self.ice_num)


class ListWrapper:
    """Wrap a plain list so it survives a ``BatchEncoding.to(device)`` call:
    the list is returned unchanged, since there is nothing to move."""

    def __init__(self, data: List[Any]):
        self.data = data

    def to(self, device):
        return self.data


def ignore_pad_dict(features):
    # Pull non-tensor metadata out of the features before `tokenizer.pad`
    # sees them; the caller re-attaches it to the padded batch afterwards.
    res_dict = {}
    if 'metadata' in features[0]:
        res_dict['metadata'] = ListWrapper(
            [x.pop('metadata') for x in features])
    return res_dict
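
# Sketch of the helper above (hypothetical feature dicts):
#
#   feats = [{'input_ids': [1, 2], 'metadata': {'id': 0}}]
#   extra = ignore_pad_dict(feats)  # -> {'metadata': ListWrapper([...])}
#   feats                           # metadata removed, safe for tokenizer.pad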


@dataclass
class DataCollatorWithPaddingAndCuda:
    tokenizer: PreTrainedTokenizerBase
    device: object = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = 3000
    pad_to_multiple_of: Optional[int] = None

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> BatchEncoding:
        # Strip metadata so `tokenizer.pad` only sees tensor-able fields.
        res_dict = ignore_pad_dict(features)

        # Labels are padded separately so they do not interfere with the
        # padding of the inputs.
        has_labels = 'labels' in features[0]
        if has_labels:
            labels = [{'input_ids': x.pop('labels')} for x in features]
            labels = self.tokenizer.pad(
                labels,
                padding=True,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_attention_mask=True,
                return_tensors='pt',
                verbose=False)

        batch = self.tokenizer.pad(features,
                                   padding=True,
                                   max_length=self.max_length,
                                   pad_to_multiple_of=self.pad_to_multiple_of,
                                   return_attention_mask=True,
                                   return_tensors='pt',
                                   verbose=False)

        if has_labels:
            batch['labels'] = labels.input_ids
        batch.update(res_dict)

        if self.device:
            batch = batch.to(self.device)
        return batch
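
# Minimal sketch of the collator in isolation (assumes a downloadable 'gpt2'
# tokenizer; the feature dicts mirror what DatasetEncoder yields):
#
#   tok = AutoTokenizer.from_pretrained('gpt2')
#   tok.pad_token = tok.eos_token
#   collator = DataCollatorWithPaddingAndCuda(tokenizer=tok, device=None)
#   batch = collator([BatchEncoding({'input_ids': [1, 2, 3],
#                                    'attention_mask': [1, 1, 1]}),
#                     BatchEncoding({'input_ids': [4, 5],
#                                    'attention_mask': [1, 1]})])
#   batch['input_ids'].shape  # torch.Size([2, 3]); row 2 is right-padded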