from typing import Any, Optional
import copy

import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader, DistributedSampler

from src.data.dataset.randn import RandomNDataset
from src.data.var_training import VARTransformEngine


def collate_fn(batch):
    """Transpose a list of samples into per-field tuples and stack tensor fields."""
    new_batch = copy.deepcopy(batch)
    new_batch = list(zip(*new_batch))
    for i in range(len(new_batch)):
        if isinstance(new_batch[i][0], torch.Tensor):
            try:
                new_batch[i] = torch.stack(new_batch[i], dim=0)
            except RuntimeError:
                # Tensors with mismatched shapes cannot be stacked; keep the field as a tuple.
                print("Warning: could not stack tensors")
    return new_batch


class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_root,
        test_nature_root,
        test_gen_root,
        train_image_size=64,
        train_batch_size=64,
        train_num_workers=8,
        var_transform_engine: Optional[VARTransformEngine] = None,
        train_prefetch_factor=2,
        train_dataset: Optional[str] = None,
        eval_batch_size=32,
        eval_num_workers=4,
        eval_max_num_instances=50000,
        pred_batch_size=32,
        pred_num_workers=4,
        pred_seeds: Optional[str] = None,
        pred_selected_classes=None,
        num_classes=1000,
        latent_shape=(4, 64, 64),
    ):
        super().__init__()
        # Parse the comma-separated seed string (e.g. "0,1,2") into a list of ints.
        pred_seeds = [int(s) for s in pred_seeds.strip().split(",")] if pred_seeds is not None else None
        self.train_root = train_root
        self.train_image_size = train_image_size
        # Stores the dataset name here; replaced by the dataset instance in setup().
        self.train_dataset = train_dataset
        self.train_batch_size = train_batch_size
        self.train_num_workers = train_num_workers
        self.train_prefetch_factor = train_prefetch_factor
        self.test_nature_root = test_nature_root
        self.test_gen_root = test_gen_root
        self.eval_max_num_instances = eval_max_num_instances
        self.pred_seeds = pred_seeds
        self.num_classes = num_classes
        self.latent_shape = latent_shape
        self.eval_batch_size = eval_batch_size
        self.pred_batch_size = pred_batch_size
        self.pred_num_workers = pred_num_workers
        self.eval_num_workers = eval_num_workers
        self.pred_selected_classes = pred_selected_classes
        self._train_dataloader = None
        self.var_transform_engine = var_transform_engine

    def setup(self, stage: str) -> None:
        if stage == "fit":
            assert self.train_dataset is not None
            if self.train_dataset == "pix_imagenet64":
                from src.data.dataset.imagenet import PixImageNet64
                self.train_dataset = PixImageNet64(
                    root=self.train_root,
                )
            elif self.train_dataset == "pix_imagenet128":
                from src.data.dataset.imagenet import PixImageNet128
                self.train_dataset = PixImageNet128(
                    root=self.train_root,
                )
            elif self.train_dataset == "imagenet256":
                from src.data.dataset.imagenet import ImageNet256
                self.train_dataset = ImageNet256(
                    root=self.train_root,
                )
            elif self.train_dataset == "pix_imagenet256":
                from src.data.dataset.imagenet import PixImageNet256
                self.train_dataset = PixImageNet256(
                    root=self.train_root,
                )
            elif self.train_dataset == "imagenet512":
                from src.data.dataset.imagenet import ImageNet512
                self.train_dataset = ImageNet512(
                    root=self.train_root,
                )
            elif self.train_dataset == "pix_imagenet512":
                from src.data.dataset.imagenet import PixImageNet512
                self.train_dataset = PixImageNet512(
                    root=self.train_root,
                )
            else:
                raise NotImplementedError(f"no such dataset: {self.train_dataset}")

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        # Apply the VAR batch transform during training only, before the batch moves to device.
        if self.var_transform_engine and self.trainer.training:
            batch = self.var_transform_engine(batch)
        return batch

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        sampler = DistributedSampler(
            self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True
        )
        self._train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            timeout=6000,
            num_workers=self.train_num_workers,
            prefetch_factor=self.train_prefetch_factor,
            sampler=sampler,
            collate_fn=collate_fn,
        )
        return self._train_dataloader

    def val_dataloader(self) -> EVAL_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        self.eval_dataset = RandomNDataset(
            latent_shape=self.latent_shape,
            num_classes=self.num_classes,
            max_num_instances=self.eval_max_num_instances,
        )
        sampler = DistributedSampler(
            self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False
        )
        return DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            num_workers=self.eval_num_workers,
            prefetch_factor=2,
            collate_fn=collate_fn,
            sampler=sampler,
        )

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        self.pred_dataset = RandomNDataset(
            seeds=self.pred_seeds,
            max_num_instances=50000,
            num_classes=self.num_classes,
            selected_classes=self.pred_selected_classes,
            latent_shape=self.latent_shape,
        )
        sampler = DistributedSampler(
            self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False
        )
        return DataLoader(
            self.pred_dataset,
            batch_size=self.pred_batch_size,
            num_workers=self.pred_num_workers,
            prefetch_factor=4,
            collate_fn=collate_fn,
            sampler=sampler,
        )
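

# Minimal usage sketch, not part of the module's API: the root paths and Trainer
# settings below are hypothetical placeholders; only the keyword arguments and the
# "imagenet256" dataset name are taken from the definitions above.
if __name__ == "__main__":
    datamodule = DataModule(
        train_root="/path/to/imagenet/train",        # hypothetical path
        test_nature_root="/path/to/imagenet/val",    # hypothetical path
        test_gen_root="/path/to/generated/samples",  # hypothetical path
        train_dataset="imagenet256",
        train_batch_size=64,
        pred_seeds="0,1,2",
    )
    trainer = pl.Trainer(max_epochs=1, accelerator="gpu", devices=1)
    # Trainer.fit() would call setup("fit") and train_dataloader(); validation and
    # prediction use the RandomNDataset-backed loaders defined above.
    # trainer.fit(model, datamodule=datamodule)  # `model` is a user-provided LightningModule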