from typing import Any, Optional

import copy

import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader, DistributedSampler

from src.data.dataset.randn import RandomNDataset
from src.data.var_training import VARTransformEngine
def collate_fn(batch):
    """Transpose a batch of per-sample tuples and stack tensor fields along a new batch dim."""
    new_batch = copy.deepcopy(batch)
    new_batch = list(zip(*new_batch))
    for i in range(len(new_batch)):
        if isinstance(new_batch[i][0], torch.Tensor):
            try:
                new_batch[i] = torch.stack(new_batch[i], dim=0)
            except RuntimeError:
                # Tensors with mismatched shapes cannot be stacked; leave this field as a tuple.
                print("Warning: could not stack tensors")
    return new_batch
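
# Illustrative example (made-up shapes, not part of the pipeline): a batch of
# (latent, label) samples such as
#     [(torch.randn(4, 64, 64), 7), (torch.randn(4, 64, 64), 2)]
# collates to
#     [tensor of shape (2, 4, 64, 64), (7, 2)]
# i.e. tensor fields gain a leading batch dimension while non-tensor fields
# remain plain tuples.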
class DataModule(pl.LightningDataModule):
    def __init__(self,
                 train_root,
                 test_nature_root,
                 test_gen_root,
                 train_image_size=64,
                 train_batch_size=64,
                 train_num_workers=8,
                 var_transform_engine: Optional[VARTransformEngine] = None,
                 train_prefetch_factor=2,
                 train_dataset: Optional[str] = None,
                 eval_batch_size=32,
                 eval_num_workers=4,
                 eval_max_num_instances=50000,
                 pred_batch_size=32,
                 pred_num_workers=4,
                 pred_seeds: Optional[str] = None,
                 pred_selected_classes=None,
                 num_classes=1000,
                 latent_shape=(4, 64, 64),
                 ):
        super().__init__()
        # pred_seeds arrives as a comma-separated string, e.g. "0,1,2".
        pred_seeds = [int(s) for s in pred_seeds.strip().split(",")] if pred_seeds is not None else None
        self.train_root = train_root
        self.train_image_size = train_image_size
        # train_dataset is a dataset-name string here; setup() replaces it with
        # the instantiated dataset.
        self.train_dataset = train_dataset
        self.train_batch_size = train_batch_size
        self.train_num_workers = train_num_workers
        self.train_prefetch_factor = train_prefetch_factor
        self.test_nature_root = test_nature_root
        self.test_gen_root = test_gen_root
        self.eval_max_num_instances = eval_max_num_instances
        self.pred_seeds = pred_seeds
        self.num_classes = num_classes
        self.latent_shape = latent_shape
        self.eval_batch_size = eval_batch_size
        self.pred_batch_size = pred_batch_size
        self.pred_num_workers = pred_num_workers
        self.eval_num_workers = eval_num_workers
        self.pred_selected_classes = pred_selected_classes
        self._train_dataloader = None
        self.var_transform_engine = var_transform_engine
    def setup(self, stage: str) -> None:
        if stage == "fit":
            assert self.train_dataset is not None
            # Resolve the dataset-name string to its class, import lazily, and
            # replace the string with the instantiated dataset.
            dataset_classes = {
                "pix_imagenet64": "PixImageNet64",
                "pix_imagenet128": "PixImageNet128",
                "imagenet256": "ImageNet256",
                "pix_imagenet256": "PixImageNet256",
                "imagenet512": "ImageNet512",
                "pix_imagenet512": "PixImageNet512",
            }
            if self.train_dataset not in dataset_classes:
                raise NotImplementedError(f"no such dataset: {self.train_dataset}")
            import src.data.dataset.imagenet as imagenet
            dataset_cls = getattr(imagenet, dataset_classes[self.train_dataset])
            self.train_dataset = dataset_cls(root=self.train_root)
    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        # Apply the VAR transform only during training; validation and
        # prediction batches pass through unchanged.
        if self.var_transform_engine and self.trainer.training:
            batch = self.var_transform_engine(batch)
        return batch
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        # Shard the dataset across ranks manually; shuffling is delegated to the sampler.
        sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
        self._train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            timeout=6000,
            num_workers=self.train_num_workers,
            prefetch_factor=self.train_prefetch_factor,
            sampler=sampler,
            collate_fn=collate_fn,
        )
        return self._train_dataloader
    def val_dataloader(self) -> EVAL_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        # Validation samples are random latents and labels, not real images.
        self.eval_dataset = RandomNDataset(
            latent_shape=self.latent_shape,
            num_classes=self.num_classes,
            max_num_instances=self.eval_max_num_instances,
        )
        sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
        return DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            num_workers=self.eval_num_workers,
            prefetch_factor=2,
            collate_fn=collate_fn,
            sampler=sampler,
        )
    def predict_dataloader(self) -> EVAL_DATALOADERS:
        global_rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        self.pred_dataset = RandomNDataset(
            seeds=self.pred_seeds,
            max_num_instances=50000,
            num_classes=self.num_classes,
            selected_classes=self.pred_selected_classes,
            latent_shape=self.latent_shape,
        )
        sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
        return DataLoader(
            self.pred_dataset,
            batch_size=self.pred_batch_size,
            num_workers=self.pred_num_workers,
            prefetch_factor=4,
            collate_fn=collate_fn,
            sampler=sampler,
        )
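

# Minimal usage sketch, not part of the module's API: the paths below are
# hypothetical placeholders, and the model is assumed to be defined elsewhere.
# Because the dataloaders build their own DistributedSampler, it is likely
# safest to pass use_distributed_sampler=False so Lightning does not try to
# inject its own sampler on top.
if __name__ == "__main__":
    datamodule = DataModule(
        train_root="/path/to/imagenet",      # hypothetical path
        test_nature_root="/path/to/nature",  # hypothetical path
        test_gen_root="/path/to/generated",  # hypothetical path
        train_dataset="imagenet256",
        pred_seeds="0,1,2",
    )
    trainer = pl.Trainer(use_distributed_sampler=False, max_epochs=1)
    # trainer.fit(model, datamodule=datamodule)  # model defined elsewhere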