Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils.data.dataloader import default_collate | |
from . import FairseqDataset | |
class BaseWrapperDataset(FairseqDataset):
    """Dataset that wraps another dataset and forwards every call to it.

    Each method below delegates to ``self.dataset``. Subclasses override only
    the operations they want to change and inherit plain delegation for the
    rest.
    """

    def __init__(self, dataset):
        """Store *dataset* as the wrapped dataset all calls are forwarded to."""
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        """Return ``self.dataset[index]``."""
        return self.dataset[index]

    def __len__(self):
        """Return ``len(self.dataset)``."""
        return len(self.dataset)

    def collater(self, samples):
        """Merge *samples* into a batch.

        Uses the wrapped dataset's own ``collater`` when it has one and falls
        back to PyTorch's ``default_collate`` otherwise.
        """
        if hasattr(self.dataset, "collater"):
            return self.dataset.collater(samples)
        return default_collate(samples)

    # NOTE: ``sizes``, ``supports_prefetch`` and
    # ``can_reuse_epoch_itr_across_epochs`` are read from the wrapped dataset
    # as plain attributes, so they must be exposed as properties here too.
    # Otherwise stacking one wrapper on another would hand back a bound method
    # (always truthy) instead of the underlying value.

    @property
    def sizes(self):
        """Return the wrapped dataset's ``sizes`` attribute."""
        return self.dataset.sizes

    def num_tokens(self, index):
        """Return ``self.dataset.num_tokens(index)``."""
        return self.dataset.num_tokens(index)

    def size(self, index):
        """Return ``self.dataset.size(index)``."""
        return self.dataset.size(index)

    def ordered_indices(self):
        """Return ``self.dataset.ordered_indices()``."""
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        """Return the wrapped dataset's ``supports_prefetch``, or ``False`` if absent."""
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        """Return ``self.dataset.attr(attr, index)``."""
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        """Forward the prefetch request for *indices* to the wrapped dataset."""
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        """Return ``self.dataset.get_batch_shapes()``."""
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """Delegate batching of *indices* to the wrapped dataset.

        All arguments are passed through unchanged to
        ``self.dataset.batch_by_size``.
        """
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """Return ``self.dataset.filter_indices_by_size(indices, max_sizes)``."""
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """Return the wrapped dataset's ``can_reuse_epoch_itr_across_epochs``."""
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        """Notify this dataset, then the wrapped one, of the new *epoch*.

        The wrapped dataset is only notified if it defines ``set_epoch``.
        """
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)