from functools import partial
import weakref
import torch
import torch.utils.data

import pointcept.utils.comm as comm
from pointcept.datasets.utils import point_collate_fn
from pointcept.datasets import ConcatDataset
from pointcept.utils.env import set_seed


class MultiDatasetDummySampler:
    """Stand-in sampler that forwards set_epoch to every sub-dataloader."""

    def __init__(self):
        self.dataloader = None

    def set_epoch(self, epoch):
        if comm.get_world_size() > 1:
            for dataloader in self.dataloader.dataloaders:
                dataloader.sampler.set_epoch(epoch)
        return


class MultiDatasetDataloader:
    """
    Dataloader over multiple datasets: each batch is drawn from a single
    sub-dataset, and the mix ratio between sub-datasets is determined by the
    loop of each sub-dataset. The overall length is determined by the main
    (first) dataset and the loop of the concat dataset.
    """
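    # Example: with sub-dataset loops [2, 1], each round of iteration yields
    # two batches from dataset 0 followed by one batch from dataset 1, until
    # the main (first) dataset is exhausted.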
|
|
def __init__( |
|
self, |
|
concat_dataset: ConcatDataset, |
|
batch_size_per_gpu: int, |
|
num_worker_per_gpu: int, |
|
mix_prob=0, |
|
seed=None, |
|
): |
|
self.datasets = concat_dataset.datasets |
|
self.ratios = [dataset.loop for dataset in self.datasets] |
|
|
|
for dataset in self.datasets: |
|
dataset.loop = 1 |
|
|
|
self.datasets[0].loop = concat_dataset.loop |
|
|
|
num_workers = num_worker_per_gpu // len(self.datasets) |
|
self.dataloaders = [] |
|
for dataset_id, dataset in enumerate(self.datasets): |
|
if comm.get_world_size() > 1: |
|
sampler = torch.utils.data.distributed.DistributedSampler(dataset) |
|
else: |
|
sampler = None |
|
|
|
init_fn = ( |
|
partial( |
|
self._worker_init_fn, |
|
dataset_id=dataset_id, |
|
num_workers=num_workers, |
|
num_datasets=len(self.datasets), |
|
rank=comm.get_rank(), |
|
seed=seed, |
|
) |
|
if seed is not None |
|
else None |
|
) |
|
self.dataloaders.append( |
|
torch.utils.data.DataLoader( |
|
dataset, |
|
batch_size=batch_size_per_gpu, |
|
shuffle=(sampler is None), |
|
num_workers=num_worker_per_gpu, |
|
sampler=sampler, |
|
collate_fn=partial(point_collate_fn, mix_prob=mix_prob), |
|
pin_memory=True, |
|
worker_init_fn=init_fn, |
|
drop_last=True, |
|
persistent_workers=True, |
|
) |
|
) |
|
self.sampler = MultiDatasetDummySampler() |
|
self.sampler.dataloader = weakref.proxy(self) |
|
|
|
    def __iter__(self):
        iterator = [iter(dataloader) for dataloader in self.dataloaders]
        while True:
            # Round-robin over the sub-dataloaders: dataset i contributes
            # ratios[i] consecutive batches per round.
            for i in range(len(self.ratios)):
                for _ in range(self.ratios[i]):
                    try:
                        batch = next(iterator[i])
                    except StopIteration:
                        if i == 0:
                            # The main dataset defines the epoch; stop once
                            # it is exhausted.
                            return
                        else:
                            # Restart exhausted auxiliary datasets.
                            iterator[i] = iter(self.dataloaders[i])
                            batch = next(iterator[i])
                    yield batch

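    # Example (derived from the loop above): with ratios [2, 1] and a main
    # dataloader of 4 batches, the yield order is A A B A A B; iteration ends
    # when the main (first) iterator is exhausted, while auxiliary iterators
    # are restarted as needed.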
    def __len__(self):
        main_data_loader_length = len(self.dataloaders[0])
        return (
            main_data_loader_length // self.ratios[0] * sum(self.ratios)
            + main_data_loader_length % self.ratios[0]
        )

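    # Example (derived from the formula above): with ratios [2, 1] and a main
    # dataloader of 10 batches, len(self) == 10 // 2 * (2 + 1) + 10 % 2 == 15:
    # five full rounds of three batches, with no leftover main batches.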
    @staticmethod
    def _worker_init_fn(worker_id, num_workers, dataset_id, num_datasets, rank, seed):
        # Derive a unique, reproducible seed for every (rank, dataset, worker)
        # combination: ranks are spaced num_workers * num_datasets apart and
        # datasets num_workers apart, so worker_id never collides.
        worker_seed = (
            num_workers * num_datasets * rank
            + num_workers * dataset_id
            + worker_id
            + seed
        )
        set_seed(worker_seed)
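
# Usage sketch (illustrative; the ConcatDataset construction is elided and the
# parameter values below are assumptions, not part of this module):
#
#   concat = ConcatDataset(...)  # sub-datasets each carry a `loop` attribute
#   loader = MultiDatasetDataloader(
#       concat,
#       batch_size_per_gpu=4,
#       num_worker_per_gpu=8,
#       mix_prob=0.8,
#       seed=42,
#   )
#   loader.sampler.set_epoch(epoch)  # keep distributed shuffling in sync
#   for batch in loader:
#       ...  # each batch comes from a single sub-dataset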