|
|
|
|
|
import numpy as np |
|
import os |
|
import pytorch_lightning as pl |
|
import torch |
|
import webdataset as wds |
|
from torchvision.transforms import transforms |
|
|
|
from ldm.util import instantiate_from_config |
|
|
|
|
|
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dicts into a single batch dict, preserving keys.

    Only keys present in *every* sample are kept (set intersection), so
    heterogeneous samples never produce ragged batches.

    :param list samples: list of samples, each a dict
    :param bool combine_tensors: stack ``torch.Tensor`` values with
        ``torch.stack`` and combine ``np.ndarray`` values with ``np.array``;
        when False the values are kept as a plain Python list
    :param bool combine_scalars: combine int/float values into a single
        ``np.ndarray``; when False they are kept as a plain Python list
    :returns: single sample consisting of a batch
    :rtype: dict
    """
    # Keep only the keys shared by all samples.
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [s[key] for s in samples] for key in keys}

    result = {}
    for key, values in batched.items():
        first = values[0]
        if isinstance(first, (int, float)) and combine_scalars:
            result[key] = np.array(values)
        elif isinstance(first, torch.Tensor) and combine_tensors:
            result[key] = torch.stack(values)
        elif isinstance(first, np.ndarray) and combine_tensors:
            result[key] = np.array(values)
        else:
            # Fallback: keep the raw per-sample list. This also covers the
            # case where combining is disabled — the previous implementation
            # silently dropped such keys from the batch entirely.
            result[key] = list(values)
    return result
|
|
|
|
|
class WebDataModuleFromConfig(pl.LightningDataModule):
    """LightningDataModule that builds WebDataset loaders from a config.

    Each dataset config (``train``/``validation``/``test``) must provide at
    least ``shards`` (path joined onto ``tar_base``), ``image_transforms``
    (list of instantiable transform configs) and ``process`` (an instantiable
    per-sample callable). Only the train loader is implemented; ``val`` and
    ``test`` dataloaders currently return ``None``.
    """

    def __init__(self,
                 tar_base,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 num_workers=4,
                 multinode=True,
                 min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        """Store loader configuration.

        :param tar_base: root directory/prefix that each config's ``shards``
            path is joined onto
        :param batch_size: batch size; batching happens inside the WebDataset
            pipeline, not in the DataLoader (which gets ``batch_size=None``)
        :param train: dataset config for the training split
        :param validation: dataset config for validation (stored but unused)
        :param test: dataset config for test (stored but unused)
        :param num_workers: number of DataLoader worker processes
        :param multinode: if True, split shards across nodes via
            ``wds.shardlists.split_by_node``
        :param min_size: minimum ``original_width``/``original_height`` a
            sample's json metadata must report; ``None`` disables size/
            watermark filtering entirely
        :param max_pwatermark: maximum allowed ``pwatermark`` value
        """
        # NOTE(review): extra **kwargs are accepted but silently ignored.
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size
        self.max_pwatermark = max_pwatermark

    def make_loader(self, dataset_config):
        """Build a WebLoader over the shards described by ``dataset_config``.

        Pipeline order: key filter -> PIL decode -> size/watermark filter ->
        per-key image transforms on ``jpg`` -> ``process`` callable ->
        batching with :func:`dict_collation_fn`.
        """
        # Compose the configured image transforms into a single callable
        # applied only to the 'jpg' entry of each sample.
        image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        image_transforms = transforms.Compose(image_transforms)

        # Per-sample postprocessing callable (runs after decoding/transforms).
        process = instantiate_from_config(dataset_config['process'])

        # shuffle is a buffer size; 0 disables both sample and shard shuffling.
        shuffle = dataset_config.get('shuffle', 0)
        shardshuffle = shuffle > 0

        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        tars = os.path.join(self.tar_base, dataset_config.shards)

        # .repeat() makes the stream infinite, so epoch length must be
        # bounded elsewhere (e.g. by the trainer's max_steps).
        dset = wds.WebDataset(
            tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        dset = (
            dset.select(self.filter_keys).decode('pil',
                handler=wds.warn_and_continue).select(self.filter_size).map_dict(
                    jpg=image_transforms, handler=wds.warn_and_continue).map(process))
        # partial=False drops the trailing incomplete batch.
        dset = (dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn))

        # Batching already happened in the pipeline, hence batch_size=None.
        loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)

        return loader

    def filter_size(self, x):
        """Return True if sample ``x`` passes the size/watermark thresholds.

        Disabled (always True) when ``min_size`` is None. Any error while
        reading the metadata (missing 'json', missing keys) rejects the
        sample rather than crashing the pipeline.
        """
        if self.min_size is None:
            return True
        try:
            # assumes x['json'] carries LAION-style metadata with
            # original_width/original_height/pwatermark — TODO confirm
            return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size and x[
                'json']['pwatermark'] <= self.max_pwatermark
        except Exception:
            return False

    def filter_keys(self, x):
        """Return True if the sample has both an image ('jpg') and a caption ('txt')."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        """Build and return the training WebLoader from ``self.train``."""
        return self.make_loader(self.train)

    def val_dataloader(self):
        # NOTE(review): self.validation is stored but never used — validation
        # is effectively disabled. Confirm whether this is intentional.
        return None

    def test_dataloader(self):
        # NOTE(review): self.test is stored but never used.
        return None
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the train loader from the canny-SD config and print
    # what comes out of the webdataset pipeline (runs until interrupted,
    # since the underlying dataset repeats forever).
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
    data_module = WebDataModuleFromConfig(**cfg["data"]["params"])
    train_loader = data_module.train_dataloader()

    for sample_batch in train_loader:
        print(sample_batch.keys())
        print(sample_batch['jpg'].shape)
|
|