Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from . import FairseqDataset | |
class ConcatSentencesDataset(FairseqDataset):
    """Wrap several equally-long datasets and concatenate their items per index.

    ``self[i]`` is ``torch.cat([ds[i] for ds in datasets])``, so each child is
    presumably expected to return a tensor for every index (TODO confirm with
    callers). Batching (``collater``) and ordering (``ordered_indices``) are
    delegated to the first dataset. Per-example sizes and token counts are
    summed across all children.
    """

    def __init__(self, *datasets):
        """Store the child datasets.

        Args:
            *datasets: the datasets whose items are concatenated. All must have
                the same length; this is checked with an ``assert``.
        """
        super().__init__()
        self.datasets = datasets
        assert all(
            len(ds) == len(datasets[0]) for ds in datasets
        ), "datasets must have the same length"

    def __getitem__(self, index):
        """Return the items at ``index`` of every child, joined with ``torch.cat``."""
        return torch.cat([ds[index] for ds in self.datasets])

    def __len__(self):
        """Return the shared length, taken from the first dataset."""
        return len(self.datasets[0])

    def collater(self, samples):
        """Collate ``samples`` with the first dataset's ``collater``."""
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        """Per-example sizes: the sum of every child's ``sizes``.

        This is a property rather than a method because this class reads
        ``ds.sizes`` as an attribute of each child. As a plain method, a
        nested ConcatSentencesDataset would hand a bound method to its
        parent's ``sum`` and make it fail. Presumably each child's ``sizes``
        supports element-wise ``+`` (e.g. an array); TODO confirm.
        """
        return sum(ds.sizes for ds in self.datasets)

    def num_tokens(self, index):
        """Return the summed ``num_tokens(index)`` of all children."""
        return sum(ds.num_tokens(index) for ds in self.datasets)

    def size(self, index):
        """Return the summed ``size(index)`` of all children."""
        return sum(ds.size(index) for ds in self.datasets)

    def ordered_indices(self):
        """Return the first dataset's ``ordered_indices()``."""
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        """True if any child dataset supports prefetching.

        This must be a property. ``prefetch`` below (and any wrapping
        ConcatSentencesDataset) tests ``getattr(ds, "supports_prefetch",
        False)``, and a bound method would always be truthy there.
        """
        return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)

    def prefetch(self, indices):
        """Forward ``indices`` to every child whose ``supports_prefetch`` is truthy."""
        for ds in self.datasets:
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    def set_epoch(self, epoch):
        """Record ``epoch`` on this dataset and forward it to children that accept it."""
        super().set_epoch(epoch)
        for ds in self.datasets:
            # Not every child is guaranteed to implement set_epoch.
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)