# opencompass/openicl/icl_retriever/icl_mdl_retriever.py
"""MDL Retriever.""" | |
from typing import List, Optional | |
import numpy as np | |
import torch | |
import tqdm | |
from transformers import AutoModelForCausalLM | |
from opencompass.openicl.icl_prompt_template import PromptTemplate | |
from opencompass.openicl.icl_retriever.icl_topk_retriever import TopkRetriever | |
from opencompass.openicl.utils.logging import get_logger | |
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS | |
logger = get_logger(__name__) | |


@ICL_RETRIEVERS.register_module()
class MDLRetriever(TopkRetriever):
    """MDL Retriever, subclass of `TopkRetriever`. MDL is an abbreviation of
    Minimum Description Length, specially designed for ppl evaluation. You may
    refer to the paper for more details: https://arxiv.org/pdf/2212.10375.pdf.

    Args:
        dataset (`BaseDataset`): Any BaseDataset instance.
            Attributes of ``reader``, ``train`` and ``test`` will be used.
        ice_separator (`Optional[str]`): The separator between each in-context
            example template when origin `PromptTemplate` is provided.
            Defaults to '\n'.
        ice_eos_token (`Optional[str]`): The end-of-sentence token for the
            in-context example template when origin `PromptTemplate` is
            provided. Defaults to '\n'.
        ice_num (`Optional[int]`): The number of in-context example templates
            when origin `PromptTemplate` is provided. Defaults to 1.
        sentence_transformers_model_name (`Optional[str]`): The name of the
            sentence transformers model. Defaults to 'all-mpnet-base-v2'.
        tokenizer_name (`Optional[str]`): The name of the tokenizer. Defaults
            to 'gpt2-xl'.
        batch_size (`Optional[int]`): The batch size for the dataloader.
            Defaults to 1.
        candidate_num (`Optional[int]`): The number of candidates to retrieve
            for each example. Defaults to 1.
        ce_model_name (`Optional[str]`): The name of the model for calculating
            MDL. Defaults to 'gpt2-xl'.
        select_time (`Optional[int]`): The number of candidate sets to score
            with MDL. Defaults to 5.
        ice_template (`Optional[PromptTemplate]`): The template for in-context
            examples. Defaults to None.
        prompt_template (`Optional[PromptTemplate]`): The template for the
            prompt. Defaults to None.
        labels (`Optional[List]`): The labels for calculating MDL. Defaults to
            None.
        seed (`Optional[int]`): The seed for random sampling. Defaults to 1.
    """
    metric_model = None  # shared CE model, lazily loaded in ``cal_ce``

    def __init__(self,
                 dataset,
                 ice_separator: Optional[str] = '\n',
                 ice_eos_token: Optional[str] = '\n',
                 ice_num: Optional[int] = 1,
                 sentence_transformers_model_name: Optional[
                     str] = 'all-mpnet-base-v2',
                 tokenizer_name: Optional[str] = 'gpt2-xl',
                 batch_size: Optional[int] = 1,
                 candidate_num: Optional[int] = 1,
                 ce_model_name: Optional[str] = 'gpt2-xl',
                 select_time: Optional[int] = 5,
                 ice_template: Optional[PromptTemplate] = None,
                 prompt_template: Optional[PromptTemplate] = None,
                 labels: Optional[List] = None,
                 seed: Optional[int] = 1) -> None:
        super().__init__(dataset, ice_separator, ice_eos_token, ice_num,
                         sentence_transformers_model_name, tokenizer_name,
                         batch_size)
        self.ce_model_name = ce_model_name
        self.candidate_num = candidate_num
        self.select_time = select_time
        self.ice_template = ICL_PROMPT_TEMPLATES.build(ice_template)
        if prompt_template is not None:
            self.prompt_template = ICL_PROMPT_TEMPLATES.build(prompt_template)
        else:
            self.prompt_template = None
        self.labels = labels
        self.seed = seed

    def topk_search(self):
        np.random.seed(self.seed)
        res_list = self.forward(self.dataloader)
        rtr_idx_list = [[] for _ in range(len(res_list))]

        logger.info('Retrieving data for test set...')
        for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
            idx = entry['metadata']['id']
            embed = np.expand_dims(entry['embed'], axis=0)
            # Nearest-neighbour candidate pool from the FAISS index;
            # ``search`` returns (distances, indices).
            near_ids = self.index.search(
                embed, min(self.candidate_num,
                           len(self.index_ds)))[1][0].tolist()
            candidates = []
            mdl_scores = []
            for j in range(self.select_time):
                if j == 0:
                    # First candidate set: the plain top-k neighbours.
                    rand_idx_list = near_ids[:self.ice_num]
                else:
                    # Remaining candidate sets: random subsets of the pool.
                    rand_idx_list = np.random.choice(near_ids,
                                                     self.ice_num,
                                                     replace=False)
                rand_idx_list = [int(i) for i in rand_idx_list]
                candidates.append(rand_idx_list)

                ice = self.generate_ice(rand_idx_list,
                                        ice_template=self.ice_template)
                ice = str(ice)
                # The in-context-example prefix is masked out of the CE
                # computation so only the query/label part is scored.
                mask_length = len(
                    self.tokenizer(ice + self.ice_eos_token,
                                   verbose=False)['input_ids'])
                if self.labels is None:
                    labels = self.get_labels(self.ice_template,
                                             self.prompt_template)
                else:
                    labels = self.labels
                prompt_list = []
                for label in labels:
                    prompt = self.generate_label_prompt(
                        idx, ice, label, self.ice_template,
                        self.prompt_template)
                    prompt = str(prompt)
                    prompt_list.append(prompt)
                loss_list = self.cal_ce(prompt_list, mask_length=mask_length)
                probs = np.exp(-np.array(loss_list))
                normalized_probs = probs / probs.sum(0, keepdims=True)
                # Lower entropy over labels means a more confident prediction
                # (a shorter description length), so keep the candidate set
                # with the highest negative entropy.
                neg_entropy = -entropy(normalized_probs, label_dim=0)
                mdl_scores.append(neg_entropy)

            rtr_idx_list[idx] = candidates[mdl_scores.index(max(mdl_scores))]
            rtr_idx_list[idx] = [int(i) for i in rtr_idx_list[idx]]

        return rtr_idx_list
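
    # For intuition (hypothetical numbers): if a candidate set's per-label CE
    # losses are [0.5, 1.2, 0.9], then probs = exp(-loss) is approximately
    # [0.61, 0.30, 0.41] and the normalized distribution is approximately
    # [0.46, 0.23, 0.31]. Its entropy (about 1.06) is close to the ln(3)
    # maximum of about 1.10, so this candidate scores poorly; a candidate set
    # that concentrates probability on a single label scores higher.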

    def retrieve(self):
        """Retrieve the in-context example indices for each test example."""
        return self.topk_search()
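
    # A minimal usage sketch (illustrative only): ``my_dataset`` and
    # ``my_ice_template`` are placeholders for a concrete BaseDataset and a
    # PromptTemplate config from your experiment; the numbers are arbitrary.
    #
    #   retriever = MDLRetriever(my_dataset,
    #                            ice_num=4,
    #                            candidate_num=30,
    #                            select_time=10,
    #                            ice_template=my_ice_template)
    #   ice_indices = retriever.retrieve()  # one index list per test example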

    def cal_ce(self, input_texts: List[str], mask_length=None):
        if self.metric_model is None:
            logger.info(
                f'Load model {self.ce_model_name} for calculating MDL...')
            self.metric_model = AutoModelForCausalLM.from_pretrained(
                self.ce_model_name)
            self.metric_model.to(self.device)
        inputs = self.tokenizer(input_texts,
                                padding=True,
                                return_tensors='pt',
                                truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        outputs = self.metric_model(**inputs)

        # Shift so that tokens < n predict the n-th token.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = inputs['input_ids'][..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction='none', ignore_index=self.tokenizer.pad_token_id)
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        loss = loss_fct(shift_logits,
                        shift_labels.view(-1)).view(shift_labels.size())
        if mask_length is not None:
            # Zero the loss on the first ``mask_length`` tokens (the
            # in-context-example prefix), keeping only the continuation.
            mask = torch.cat([
                torch.zeros([loss.shape[0], mask_length], dtype=torch.float),
                torch.ones([loss.shape[0], loss.shape[-1] - mask_length],
                           dtype=torch.float)
            ], -1)
            mask = mask.to(self.device)
            loss = torch.mul(mask, loss)

        lens = (inputs['input_ids'] !=
                self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
        if mask_length is not None:
            lens -= mask_length
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss
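
    # Masking sketch (hypothetical numbers): for input_ids of length 6 with
    # no padding and mask_length=3, the shifted loss has 5 positions and the
    # mask [0, 0, 0, 1, 1] zeroes the in-context prefix, so only the last two
    # positions contribute to loss.sum(-1); lens becomes 6 - 3 = 3, and
    # ce_loss is that masked sum divided by 3.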


def entropy(probs: np.ndarray, label_dim: int = 0, mask=None):
    """Shannon entropy of ``probs`` along ``label_dim``."""
    if mask is None:
        return -(probs * np.log(probs)).sum(label_dim)
    return -(mask * probs * np.log(probs)).sum(label_dim)
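

if __name__ == '__main__':
    # A minimal, self-contained smoke test of the MDL scoring math used in
    # ``topk_search`` (illustrative only; no model or dataset is touched).
    # The loss values are made up for demonstration.
    fake_losses = np.array([0.5, 1.2, 0.9])  # hypothetical per-label CE
    probs = np.exp(-fake_losses)
    normalized = probs / probs.sum(0, keepdims=True)
    print('normalized label probs:', normalized)
    print('MDL score (neg entropy):', -entropy(normalized, label_dim=0))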