import copy
import random
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import torch
import torchvision.transforms.functional as tvtf
from PIL import Image

from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.models.vae import uint82fp


def center_crop_arr(pil_image, width, height):
    """
    Cropping implementation adapted from ADM. Unlike the original, the crop
    offset is sampled uniformly at random rather than centered.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Repeatedly halve with box filtering while the image is at least 2x the target.
    while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Rescale so the image just covers the target crop, then crop at a random offset.
    scale = max(width / pil_image.size[0], height / pil_image.size[1])
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = random.randint(0, arr.shape[0] - height)
    crop_x = random.randint(0, arr.shape[1] - width)
    return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])


def process_fn(width, height, data, hflip=0.5):
    image, label = data
    if random.uniform(0, 1) > hflip:  # horizontal flip with probability 1 - hflip
        image = tvtf.hflip(image)
    image = center_crop_arr(image, width, height)  # crop to the candidate resolution
    image = np.array(image).transpose(2, 0, 1)  # HWC -> CHW
    return image, label


class VARCandidate:
    """A bucket of pending samples that share one target aspect ratio."""

    def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
        self.aspect_ratio = aspect_ratio
        self.width = int(width)
        self.height = int(height)
        self.buffer = buffer
        self.max_buffer_size = max_buffer_size

    def add_sample(self, data):
        self.buffer.append(data)
        # Keep only the most recent max_buffer_size entries.
        self.buffer = self.buffer[-self.max_buffer_size:]

    def ready(self, batch_size):
        return len(self.buffer) >= batch_size

    def get_batch(self, batch_size):
        batch = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        # Each buffered entry is a Future; result() blocks until it is done.
        batch = [copy.deepcopy(b.result()) for b in batch]
        x, y = zip(*batch)
        x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
        x = list(map(uint82fp, x))  # uint8 -> float, one tensor per sample
        return x, y


class VARTransformEngine:
    """Routes raw samples into aspect-ratio buckets and emits same-shape batches."""

    def __init__(
        self,
        base_image_size,
        num_aspect_ratios,
        min_aspect_ratio,
        max_aspect_ratio,
        num_workers=8,
    ):
        self.base_image_size = base_image_size
        self.num_aspect_ratios = num_aspect_ratios
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.aspect_ratios = np.linspace(
            self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios
        ).tolist()

        # One bucket per aspect ratio; width/height preserve roughly the pixel
        # area of base_image_size**2 and are rounded down to multiples of 16.
        self.candidates_pool = []
        for i in range(self.num_aspect_ratios):
            candidate = VARCandidate(
                aspect_ratio=self.aspect_ratios[i],
                width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
                height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
                buffer=[],
                max_buffer_size=1024,
            )
            self.candidates_pool.append(candidate)

        # Square fallback bucket, used while the ratio buckets warm up.
        self.default_candidate = VARCandidate(
            aspect_ratio=1.0,
            width=self.base_image_size,
            height=self.base_image_size,
            buffer=[],
            max_buffer_size=1024,
        )
        self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
        self._prefill_count = 100

    def find_candidate(self, data):
        # Pick the bucket whose aspect ratio is closest to the image's.
        image = data[0]
        aspect_ratio = image.size[0] / image.size[1]
        min_distance = float("inf")
        min_candidate = None
        for candidate in self.candidates_pool:
            dis = abs(aspect_ratio - candidate.aspect_ratio)
            if dis < min_distance:
                min_distance = dis
                min_candidate = candidate
        return min_candidate

    def __call__(self, batch_data):
        self._prefill_count -= 1
        if isinstance(batch_data[0], torch.Tensor):
            batch_data[0] = batch_data[0].unbind(0)
        batch_data = list(zip(*batch_data))
        for data in batch_data:
            candidate = self.find_candidate(data)
            future = self.executor_pool.submit(
                process_fn, candidate.width, candidate.height, data
            )
            candidate.add_sample(future)
            # During the prefill phase, also feed the default bucket so the
            # fallback below always has enough samples.
            if self._prefill_count >= 0:
                future = self.executor_pool.submit(
                    process_fn,
                    self.default_candidate.width,
                    self.default_candidate.height,
                    data,
                )
                self.default_candidate.add_sample(future)

        batch_size = len(batch_data)
        # Shuffle so ready buckets are drained in random order.
        random.shuffle(self.candidates_pool)
        for candidate in self.candidates_pool:
            if candidate.ready(batch_size=batch_size):
                return candidate.get_batch(batch_size=batch_size)

        # Fall back to the default square (base_image_size) bucket.
        for data in batch_data:
            future = self.executor_pool.submit(
                process_fn,
                self.default_candidate.width,
                self.default_candidate.height,
                data,
            )
            self.default_candidate.add_sample(future)
        return self.default_candidate.get_batch(batch_size=batch_size)
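

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how the engine might be driven as a collate-time
# transform. The __main__ guard, sizes, and the synthetic PIL images below
# are assumptions for illustration; the real entry point lives elsewhere in
# the repo. ProcessPoolExecutor requires the __main__ guard on spawn-based
# platforms, and get_batch() calls .cuda(), so a GPU is assumed.
if __name__ == "__main__":
    engine = VARTransformEngine(
        base_image_size=256,
        num_aspect_ratios=5,
        min_aspect_ratio=0.5,
        max_aspect_ratio=2.0,
        num_workers=2,
    )
    # Mimic one dataloader batch: a list of PIL images plus integer labels.
    images = [Image.new("RGB", (320, 240)) for _ in range(4)]
    labels = [0, 1, 2, 3]
    x, y = engine([images, labels])
    # x is a list of per-sample float tensors sharing one bucket resolution.
    print(len(x), x[0].shape, y)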