File size: 8,136 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
"""Topk Retriever."""
import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import tqdm
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy
from opencompass.openicl.icl_dataset_reader import DatasetEncoder
from opencompass.openicl.icl_retriever import BaseRetriever
from opencompass.openicl.utils.logging import get_logger
from opencompass.registry import ICL_RETRIEVERS
# Module-level logger from OpenCompass' ICL logging helper, named after this
# module; used below for progress messages.
logger = get_logger(__name__)
@ICL_RETRIEVERS.register_module()
class TopkRetriever(BaseRetriever):
    """Base class for Topk In-context Learning Retriever, implemented with
    basic knn. SentenceTransformer is used to calculate embeddings. Faiss is
    used to do the nearest neighbor search.

    The search is exact and scores are raw inner products of the
    SentenceTransformer embeddings (``faiss.IndexFlatIP``); no normalisation
    is applied in this class.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instances.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each in-context
            example template when origin `PromptTemplate` is provided.
            Defaults to '\\n'.
        ice_eos_token (`Optional[str]`): The end of sentence token for
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\\n'.
        ice_num (`Optional[int]`): The number of in-context example template
            when origin `PromptTemplate` is provided. Defaults to 1.
        sentence_transformers_model_name (`Optional[str]`): The name of the
            sentence transformers model. Defaults to 'all-mpnet-base-v2'.
        tokenizer_name (`Optional[str]`): The name of the tokenizer used to
            batch and pad the text before embedding. Defaults to 'gpt2-xl'.
        batch_size (`Optional[int]`): The batch size for the dataloader.
            Defaults to 1.
    """
    model = None

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 sentence_transformers_model_name: Optional[
                     str] = 'all-mpnet-base-v2',
                 tokenizer_name: Optional[str] = 'gpt2-xl',
                 batch_size: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.batch_size = batch_size
        self.tokenizer_name = tokenizer_name
        gen_datalist = self.dataset_reader.generate_input_field_corpus(
            self.test_ds)

        # The tokenizer only batches/pads text; `forward` decodes it back to
        # strings with skip_special_tokens=True, so the EOS used as padding
        # is stripped again. (Presumably EOS is reused because the default
        # GPT-2 tokenizer defines no pad token — confirm for other models.)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.padding_side = 'right'

        self.encode_dataset = DatasetEncoder(gen_datalist,
                                             tokenizer=self.tokenizer)
        co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,
                                            device=self.device)
        self.dataloader = DataLoader(self.encode_dataset,
                                     batch_size=self.batch_size,
                                     collate_fn=co)

        self.model = SentenceTransformer(sentence_transformers_model_name)
        self.model = self.model.to(self.device)
        self.model.eval()

        self.index = self.create_index()

    def create_index(self):
        """Embed every example of ``self.index_ds`` and build a faiss index.

        Side effects: sets ``self.select_datalist`` (input-field corpus of
        the index set) and ``self.embed_list`` (stacked embeddings, one row
        per index-set example; empty with shape ``(0, dim)`` when the index
        set is empty).

        Returns:
            faiss.IndexIDMap: exact inner-product index whose ids are the
            ``metadata['id']`` values of the encoded index-set examples.
        """
        import faiss
        self.select_datalist = self.dataset_reader.generate_input_field_corpus(
            self.index_ds)
        encode_datalist = DatasetEncoder(self.select_datalist,
                                         tokenizer=self.tokenizer)
        co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer,
                                            device=self.device)
        dataloader = DataLoader(encode_datalist,
                                batch_size=self.batch_size,
                                collate_fn=co)
        dim = self.model.get_sentence_embedding_dimension()
        index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))
        res_list = self.forward(dataloader,
                                process_bar=True,
                                information='Creating index for index set...')
        if not res_list:
            # np.stack raises on an empty list; keep an empty index so that
            # retrieval yields no examples instead of crashing.
            self.embed_list = np.empty((0, dim), dtype=np.float32)
            return index
        # faiss ids are 64-bit; do not rely on the platform default int type.
        id_list = np.array([res['metadata']['id'] for res in res_list],
                           dtype=np.int64)
        self.embed_list = np.stack([res['embed'] for res in res_list])
        index.add_with_ids(self.embed_list, id_list)
        return index

    def knn_search(self, ice_num):
        """Find the ``ice_num`` nearest index-set examples for each test item.

        Args:
            ice_num (int): number of neighbours requested from faiss.

        Returns:
            List[List[int]]: entry ``i`` holds the index-set ids retrieved for
            the test example whose ``metadata['id']`` is ``i``, in the order
            faiss returns them. faiss pads the result with ``-1`` when fewer
            than ``ice_num`` candidates exist; such placeholders are dropped,
            so an entry may be shorter than ``ice_num``.
        """
        res_list = self.forward(self.dataloader,
                                process_bar=True,
                                information='Embedding test set...')
        rtr_idx_list = [[] for _ in range(len(res_list))]
        logger.info('Retrieving data for test set...')
        for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
            idx = entry['metadata']['id']
            embed = np.expand_dims(entry['embed'], axis=0)
            near_ids = self.index.search(embed, ice_num)[1][0].tolist()
            # Presumably callers use these ids to index the index set, where
            # a -1 would silently pick the last example; filter padding out.
            rtr_idx_list[idx] = [i for i in near_ids if i >= 0]
        return rtr_idx_list

    def forward(self, dataloader, process_bar=False, information=''):
        """Embed every example yielded by ``dataloader``.

        Each batch's ``input_ids`` are decoded back to text (special tokens,
        including the EOS used as padding, are skipped) and embedded with the
        SentenceTransformer.

        Args:
            dataloader: DataLoader whose batches provide ``input_ids`` and a
                per-example ``metadata`` sequence.
            process_bar (bool): if True, log ``information`` and wrap the
                iteration in a tqdm progress bar.
            information (str): message logged when ``process_bar`` is True.

        Returns:
            List[dict]: one ``{'embed': embedding, 'metadata': metadata}``
            dict per example, in dataloader order.
        """
        res_list = []
        # NOTE(review): the collator pops 'metadata' out of each feature in
        # place; the deep copy presumably keeps the source dataset intact so
        # the same dataloader can be iterated again — confirm before removing.
        _dataloader = copy.deepcopy(dataloader)
        if process_bar:
            logger.info(information)
            _dataloader = tqdm.tqdm(_dataloader,
                                    disable=not self.is_main_process)
        for entry in _dataloader:
            with torch.no_grad():
                metadata = entry.pop('metadata')
                raw_text = self.tokenizer.batch_decode(
                    entry['input_ids'],
                    skip_special_tokens=True,
                    verbose=False)
                res = self.model.encode(raw_text, show_progress_bar=False)
            res_list.extend([{
                'embed': r,
                'metadata': m
            } for r, m in zip(res, metadata)])
        return res_list

    def retrieve(self):
        """Retrieve the in-context example index for each test example."""
        return self.knn_search(self.ice_num)
class ListWrapper:
    """Carry a plain Python list through ``BatchEncoding.to(device)``.

    ``to`` ignores the device and returns the wrapped list, so a batch that
    is moved to a device ends up holding the raw list. The wrapper is also
    iterable, sized and indexable so consumers (e.g. ``zip(res, metadata)``)
    still work when ``to`` is never called on it — for instance when the
    collator is built with ``device=None``, or (possibly, depending on the
    transformers version) when ``BatchEncoding.to`` only moves tensors.

    Args:
        data: the list to wrap.
    """

    def __init__(self, data: List[Any]):
        self.data = data

    def to(self, device):
        """Return the wrapped list itself; ``device`` is ignored."""
        return self.data

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def __repr__(self):
        return f'{type(self).__name__}({self.data!r})'
def ignore_pad_dict(features):
    """Split off the fields that must not go through tokenizer padding.

    Only ``'metadata'`` is handled. When the first feature carries it, the
    key is popped from every feature (mutating the dicts in place) and the
    collected values are returned wrapped in a ``ListWrapper``.

    Args:
        features: list of per-example feature dicts.

    Returns:
        dict: ``{'metadata': ListWrapper([...])}``, or ``{}`` when the first
        feature has no ``'metadata'`` key.
    """
    if 'metadata' not in features[0]:
        return {}
    collected = [feature.pop('metadata') for feature in features]
    return {'metadata': ListWrapper(collected)}
@dataclass
class DataCollatorWithPaddingAndCuda:
    """Collate tokenized features into a padded batch, optionally on a device.

    ``'metadata'`` entries are split off by ``ignore_pad_dict`` before padding
    and re-attached afterwards. If the features carry ``'labels'``, those are
    padded separately and stored as ``batch['labels']``.

    Attributes:
        tokenizer: tokenizer whose ``pad`` method performs the padding.
        device: target passed to ``BatchEncoding.to``; a falsy value leaves
            the batch where ``tokenizer.pad`` created it.
        padding: padding strategy forwarded to ``tokenizer.pad``.
        max_length: forwarded to ``tokenizer.pad``.
        pad_to_multiple_of: forwarded to ``tokenizer.pad``.
    """
    tokenizer: PreTrainedTokenizerBase
    device: object = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = 3000
    pad_to_multiple_of: Optional[int] = None

    def _pad(self, features):
        """Pad ``features`` with this collator's settings into torch tensors."""
        return self.tokenizer.pad(features,
                                  padding=self.padding,
                                  max_length=self.max_length,
                                  pad_to_multiple_of=self.pad_to_multiple_of,
                                  return_attention_mask=True,
                                  return_tensors='pt',
                                  verbose=False)

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> BatchEncoding:
        """Pad ``features`` into a ``BatchEncoding``.

        Note: ``'metadata'`` and ``'labels'`` are popped from the input
        feature dicts in place.

        Args:
            features: per-example dicts with tokenizer outputs, optionally
                ``'labels'`` and ``'metadata'``.

        Returns:
            BatchEncoding: padded tensors plus ``'metadata'`` (if present)
            and ``'labels'`` (if present), moved to ``self.device`` when set.
        """
        # Presumably tokenizer.pad cannot pad metadata, so detach it first.
        res_dict = ignore_pad_dict(features)
        has_labels = 'labels' in features[0]
        if has_labels:
            labels = self._pad(
                [{'input_ids': x.pop('labels')} for x in features])
        batch = self._pad(features)
        if has_labels:
            batch['labels'] = labels.input_ids
        batch.update(res_dict)
        if self.device:
            batch = batch.to(self.device)
        return batch
|