from typing import List, Optional

import torch
import transformers
from torch.utils.data import Dataset, Sampler
from transformers.tokenization_utils_base import BatchEncoding
from transformers.trainer import RandomSampler, has_length
from transformers.trainer_pt_utils import logger
|
|
def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `num_chunks` chunks of roughly equal total length.
    """
    if len(indices) % num_chunks != 0:
        # The indices do not divide evenly; fall back to a simple strided split.
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        # Greedily assign each index to the chunk with the smallest running total.
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            # A full chunk may not take more indices; remove it from consideration.
            chunks_lengths[shortest_chunk] = float('inf')

    return chunks
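
# Worked example (an editorial sketch, not part of the original code): with
# lengths = [5, 4, 3, 2, 1, 0], calling
# split_to_even_chunks([0, 1, 2, 3, 4, 5], lengths, num_chunks=2) assigns each
# index in turn to the currently lightest chunk and returns
# [[0, 3, 4], [1, 2, 5]] (summed lengths 8 and 7).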
|
|
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    # Shuffle all indices, then carve the permutation into megabatches of
    # `world_size * batch_size` indices each.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    # Within each megabatch, sort by length (longest first) and split into one
    # length-balanced chunk per rank.
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]
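
# Shape of the result (editorial note, hedged): for example, with 8 samples,
# batch_size=2 and world_size=2, each megabatch holds 4 shuffled indices;
# after sorting by length and chunking, the flattened list reads
# [mb0-chunk0, mb0-chunk1, mb1-chunk0, mb1-chunk1], so each consecutive run
# of `batch_size` indices forms a length-balanced batch (the last megabatch
# may be ragged when the sample count does not divide evenly).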
|
|
class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same
    length while keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        dataset: Optional[Dataset] = None,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
        generator=None,
    ):
        if dataset is None and lengths is None:
            raise ValueError('One of dataset and lengths must be provided.')

        self.batch_size = batch_size
        if lengths is None:
            # Infer lengths from the dataset items, which must be dict-like and
            # contain the model input (e.g. 'input_ids').
            model_input_name = model_input_name if model_input_name is not None else 'input_ids'
            if (
                not isinstance(dataset[0], (dict, BatchEncoding))
                or model_input_name not in dataset[0]
            ):
                raise ValueError(
                    'Can only automatically infer lengths for datasets whose items are dictionaries with an '
                    f"'{model_input_name}' key."
                )
            lengths = [len(feature[model_input_name]) for feature in dataset]
        elif isinstance(lengths, torch.Tensor):
            logger.info(
                'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...'
            )
            lengths = lengths.tolist()
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(indices)
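
# Usage sketch (hedged; `train_dataset` and `example_lengths` are placeholder
# names for illustration, not names from this module):
#
#     sampler = LengthGroupedSampler(
#         batch_size=4,
#         world_size=2,
#         lengths=example_lengths,  # one int per training example
#     )
#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, sampler=sampler)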
|
|
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    if self.train_dataset is None or not has_length(self.train_dataset):
        return None

    if self.args.group_by_length:
        # The training set is assumed to be a concatenation of sub-datasets,
        # each exposing its per-example lengths via a `length` attribute.
        lengths = []
        for dataset in self.train_dataset.datasets:
            lengths += dataset.length
        model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
        return LengthGroupedSampler(
            self.args.train_batch_size,
            world_size=self.args.world_size * self.args.gradient_accumulation_steps,
            dataset=self.train_dataset,
            lengths=lengths,
            model_input_name=model_input_name,
        )
    else:
        return RandomSampler(self.train_dataset)


def replace_train_sampler():
    # Monkey-patch the Trainer so that every subsequently constructed trainer
    # uses the length-grouped sampler defined above.
    transformers.Trainer._get_train_sampler = _get_train_sampler
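

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch with made-up lengths, not
    # part of the original module). Every index must appear exactly once in
    # the length-grouped ordering.
    toy_lengths = [7, 3, 9, 1, 4, 8, 2, 6]
    order = get_length_grouped_indices(
        toy_lengths, batch_size=2, world_size=2,
        generator=torch.Generator().manual_seed(0),
    )
    assert sorted(order) == list(range(len(toy_lengths)))
    print('length-grouped order:', order)
    # In real training code, call replace_train_sampler() once before
    # Trainer.train() so the patched _get_train_sampler takes effect.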