from diffusers import AutoencoderKL from typing import Optional, Union import torch import torch.nn as nn import numpy as np from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput from diffusers.models.autoencoders.vae import DecoderOutput class PixelMixer(nn.Module): def __init__(self, in_channels, downscale_factor): super(PixelMixer, self).__init__() self.downscale_factor = downscale_factor self.in_channels = in_channels def forward(self, x): latent = self.encode(x) out = self.decode(latent) return out def encode(self, x): return torch.nn.PixelUnshuffle(self.downscale_factor)(x) def decode(self, x): return torch.nn.PixelShuffle(self.downscale_factor)(x) # for reference # none of this matters with llvae, but we need to match the interface (latent_channels might matter) class Config: in_channels = 3 out_channels = 3 down_block_types = ('1', '1', '1', '1') up_block_types = ('1', '1', '1', '1') block_out_channels = (1, 1, 1, 1) latent_channels = 192 # usually 4 norm_num_groups = 32 sample_size = 512 # scaling_factor = 1 # shift_factor = 0 scaling_factor = 1.8 shift_factor = -0.123 # VAE # - Mean: -0.12306906282901764 # - Std: 0.556016206741333 # Normalization parameters: # - Shift factor: -0.12306906282901764 # - Scaling factor: 1.7985087266803625 def __getitem__(cls, x): return getattr(cls, x) class AutoencoderPixelMixer(nn.Module): def __init__(self, in_channels=3, downscale_factor=8): super().__init__() self.mixer = PixelMixer(in_channels, downscale_factor) self._dtype = torch.float32 self._device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") self.config = Config() if downscale_factor == 8: # we go by len of block out channels in code, so simulate it self.config.block_out_channels = (1, 1, 1, 1) self.config.latent_channels = 192 elif downscale_factor == 16: # we go by len of block out channels in code, so simulate it self.config.block_out_channels = (1, 1, 1, 1, 1) self.config.latent_channels = 768 else: raise ValueError( f"downscale_factor {downscale_factor} not supported") @property def dtype(self): return self._dtype @dtype.setter def dtype(self, value): self._dtype = value @property def device(self): return self._device @device.setter def device(self, value): self._device = value # mimic to from torch def to(self, *args, **kwargs): # pull out dtype and device if they exist if 'dtype' in kwargs: self._dtype = kwargs['dtype'] if 'device' in kwargs: self._device = kwargs['device'] return super().to(*args, **kwargs) def enable_xformers_memory_efficient_attention(self): pass # @apply_forward_hook def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.mixer.encode(x) # moments = self.quant_conv(h) # posterior = DiagonalGaussianDistribution(moments) if not return_dict: return (h,) class FakeDist: def __init__(self, x): self._sample = x def sample(self): return self._sample return AutoencoderKLOutput(latent_dist=FakeDist(h)) def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: dec = self.mixer.decode(z) if not return_dict: return (dec,) return DecoderOutput(sample=dec) # @apply_forward_hook def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: decoded = self._decode(z).sample if not return_dict: return (decoded,) return DecoderOutput(sample=decoded) def _set_gradient_checkpointing(self, module, value=False): pass def enable_tiling(self, use_tiling: bool = True): pass def disable_tiling(self): pass def enable_slicing(self): pass def disable_slicing(self): pass def set_use_memory_efficient_attention_xformers(self, value: bool = True): pass def forward( self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, ) -> Union[DecoderOutput, torch.FloatTensor]: x = sample posterior = self.encode(x).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() dec = self.decode(z).sample if not return_dict: return (dec,) return DecoderOutput(sample=dec) # test it if __name__ == '__main__': import os from PIL import Image import torchvision.transforms as transforms user_path = os.path.expanduser('~') device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 input_path = os.path.join(user_path, "Pictures/test/test.jpg") output_path = os.path.join(user_path, "Pictures/test/test.jpg") img = Image.open(input_path) img_tensor = transforms.ToTensor()(img) img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype) print("input_shape: ", list(img_tensor.shape)) vae = PixelMixer(in_channels=3, downscale_factor=8) latent = vae.encode(img_tensor) print("latent_shape: ", list(latent.shape)) out_tensor = vae.decode(latent) print("out_shape: ", list(out_tensor.shape)) mse_loss = nn.MSELoss() mse = mse_loss(img_tensor, out_tensor) print("roundtrip_loss: ", mse.item()) out_img = transforms.ToPILImage()(out_tensor.squeeze(0)) out_img.save(output_path)