old_tok / paintmind /utils /datasets.py
tennant's picture
upload
af7c0ce
raw
history blame contribute delete
3.04 kB
import os
import torch
import torchvision
import numpy as np
import os.path as osp
from glob import glob
from PIL import Image
import torchvision
import torchvision.transforms as TF
def pair(t):
    """Normalize ``t`` into a tuple.

    A tuple argument is returned as-is (same object). Any other value,
    including lists and ``None``, is duplicated into the 2-tuple ``(t, t)``.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)
def center_crop_arr(pil_image, image_size):
    """Resize and center-crop a PIL image to an ``image_size`` x ``image_size`` square.

    Port of the ADM (guided-diffusion) center-crop routine:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: Input ``PIL.Image.Image``.
        image_size: Side length of the square output, in pixels.

    Returns:
        A new PIL image created from the cropped pixel array.
    """
    img = pil_image

    # Halve with a BOX filter while the short side is still >= 2x the target,
    # so the bicubic step below never shrinks by a factor of two or more.
    while min(img.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=Image.BOX)

    # Bicubic resize so that the shorter side lands on image_size (up to rounding).
    ratio = image_size / min(img.size)
    target = tuple(round(side * ratio) for side in img.size)
    img = img.resize(target, resample=Image.BICUBIC)

    # Cut the central window out of the pixel array (axis 0 = rows, axis 1 = columns).
    pixels = np.array(img)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(window)
def vae_transforms(image_set, aug='randcrop', img_size=256):
    """Build the PIL-image -> tensor transform pipeline for VAE data.

    Args:
        image_set: Split name. ``'train'`` enables augmentation; any other
            value uses a deterministic bicubic resize + center crop.
        aug: Train-time augmentation, only consulted when
            ``image_set == 'train'``:
            ``'randcrop'``  - bicubic resize of the shorter side to
            ``img_size``, then a random ``img_size`` crop;
            ``'centercrop'`` - ADM-style crop via ``center_crop_arr``.
            Both are followed by a random horizontal flip (p=0.5).
        img_size: Side length of the square output, in pixels.

    Returns:
        A ``torchvision.transforms.Compose`` ending in ``ToTensor``.

    Raises:
        ValueError: if ``image_set == 'train'`` and ``aug`` is not recognised.
    """
    from functools import partial

    steps = []
    if image_set == 'train':
        if aug == 'randcrop':
            steps.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
            steps.append(TF.RandomCrop(img_size))
        elif aug == 'centercrop':
            # A partial of a module-level function is picklable, unlike a lambda.
            # The transform lives on the dataset, which DataLoader pickles for
            # worker processes started with 'spawn'/'forkserver'.
            steps.append(TF.Lambda(partial(center_crop_arr, image_size=img_size)))
        else:
            raise ValueError(f"Invalid augmentation: {aug}")
        steps.append(TF.RandomHorizontalFlip(p=0.5))
    else:
        steps.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
        steps.append(TF.CenterCrop(img_size))
    steps.append(TF.ToTensor())
    return TF.Compose(steps)
def _center_crop_with_flip(img, img_size):
    """ADM-center-crop ``img`` and stack it with its horizontal mirror: (2, C, H, W)."""
    cropped = center_crop_arr(img, img_size)
    to_tensor = TF.ToTensor()
    return torch.stack([to_tensor(cropped), to_tensor(TF.functional.hflip(cropped))])


def _multi_scale_ten_crop(img, crop_sizes, img_size):
    """Ten-crop ``img`` at several pre-crop scales and stack every crop.

    For each size in ``crop_sizes`` (in order), the image is ADM-center-cropped
    to that size and then split by ``TenCrop(img_size)``. The result has shape
    ``(10 * len(crop_sizes), C, img_size, img_size)``.
    """
    ten_crop = TF.TenCrop(img_size)
    to_tensor = TF.ToTensor()
    crops = [piece for size in crop_sizes for piece in ten_crop(center_crop_arr(img, size))]
    return torch.stack([to_tensor(piece) for piece in crops])


def cached_transforms(aug='tencrop', img_size=256, crop_ranges=(1.05, 1.10)):
    """Build a deterministic multi-view transform that returns a stacked tensor.

    Args:
        aug: Substring-matched mode name. If it contains ``'centercrop'``
            (checked first), the output is the center crop plus its horizontal
            flip, shape ``(2, C, img_size, img_size)``. If it contains
            ``'tencrop'``, the image is pre-cropped at each scale in
            ``crop_ranges`` and ten-cropped, shape
            ``(10 * len(crop_ranges), C, img_size, img_size)``.
        img_size: Side length of every output crop, in pixels.
        crop_ranges: Pre-crop scales relative to ``img_size`` (``'tencrop'``
            only). Values are expected to be >= 1 so that each pre-crop is at
            least ``img_size`` and ``TenCrop`` can operate on it.

    Returns:
        A ``torchvision.transforms.Compose`` mapping a PIL image to a tensor.

    Raises:
        ValueError: if ``aug`` matches neither mode, or ``crop_ranges`` is
            empty in ``'tencrop'`` mode.
    """
    from functools import partial

    # Module-level helpers wrapped in partial (instead of lambdas) keep the
    # transform picklable for DataLoader workers using 'spawn'/'forkserver'.
    if 'centercrop' in aug:
        fn = partial(_center_crop_with_flip, img_size=img_size)
    elif 'tencrop' in aug:
        crop_sizes = tuple(int(img_size * crop_range) for crop_range in crop_ranges)
        if not crop_sizes:
            raise ValueError("crop_ranges must contain at least one scale for 'tencrop'")
        fn = partial(_multi_scale_ten_crop, crop_sizes=crop_sizes, img_size=img_size)
    else:
        raise ValueError(f"Invalid augmentation: {aug}")
    return TF.Compose([TF.Lambda(fn)])
class ImageNet(torchvision.datasets.ImageFolder):
    """``ImageFolder`` over ``<root>/<split>`` with this module's transforms.

    Args:
        root: Directory holding one sub-directory per split.
        split: Split sub-directory name. It is also passed to
            ``vae_transforms``, where ``'train'`` enables augmentation.
        aug: Augmentation name. Any value containing ``'cache'`` selects
            ``cached_transforms`` (stacked multi-view tensors); otherwise
            ``vae_transforms`` is used.
        img_size: Side length of the output images, in pixels.
    """

    def __init__(self, root, split='train', aug='randcrop', img_size=256):
        # Build the transform first so that an invalid ``aug`` raises before
        # the (potentially very large) directory scan in ImageFolder.
        if 'cache' not in aug:
            transform = vae_transforms(split, aug=aug, img_size=img_size)
        else:
            transform = cached_transforms(aug=aug, img_size=img_size)
        # Passing it through the constructor (instead of assigning afterwards)
        # also keeps VisionDataset's combined ``transforms`` attribute in sync.
        super().__init__(osp.join(root, split), transform=transform)