from typing import Optional, Union

import torch
import torch.nn as nn

from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput
class PixelMixer(nn.Module):
    """Lossless "VAE" stand-in: trades spatial resolution for channels via pixel (un)shuffle."""

    def __init__(self, in_channels, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor
        self.in_channels = in_channels

    def forward(self, x):
        latent = self.encode(x)
        out = self.decode(latent)
        return out

    def encode(self, x):
        # (B, C, H, W) -> (B, C * r**2, H / r, W / r); H and W must be divisible by r
        return torch.nn.functional.pixel_unshuffle(x, self.downscale_factor)

    def decode(self, x):
        # (B, C * r**2, H / r, W / r) -> (B, C, H, W)
        return torch.nn.functional.pixel_shuffle(x, self.downscale_factor)
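# Worked example (a sketch): with downscale_factor=8, a (1, 3, 512, 512) image
# encodes to a (1, 192, 64, 64) latent, since 3 * 8**2 = 192. Pixel unshuffle
# and pixel shuffle are inverse permutations, so the roundtrip is exact:
#   mixer = PixelMixer(in_channels=3, downscale_factor=8)
#   x = torch.randn(1, 3, 512, 512)
#   assert mixer(x).equal(x)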
# For reference: none of this matters with a lossless pixel-shuffle "VAE",
# but we need to match the AutoencoderKL config interface
# (latent_channels might matter to callers).
class Config:
    in_channels = 3
    out_channels = 3
    down_block_types = ('1', '1', '1', '1')  # placeholders; only the length is read
    up_block_types = ('1', '1', '1', '1')
    block_out_channels = (1, 1, 1, 1)
    latent_channels = 192  # usually 4 for a learned VAE
    norm_num_groups = 32
    sample_size = 512
    # scaling_factor = 1
    # shift_factor = 0
    scaling_factor = 1.8
    shift_factor = -0.123
    # VAE latent statistics:
    # - Mean: -0.12306906282901764
    # - Std: 0.556016206741333
    # Normalization parameters:
    # - Shift factor: -0.12306906282901764 (the mean)
    # - Scaling factor: 1.7985087266803625 (1 / std)

    def __getitem__(self, x):
        # allow dict-style access, e.g. config["latent_channels"]
        return getattr(self, x)
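# How a diffusers-style pipeline typically applies these factors (a sketch of
# the convention for VAEs that define a shift_factor, e.g. Flux-style models;
# `vae` and `image` here are hypothetical):
#   latents = (vae.encode(image).latent_dist.sample() - shift_factor) * scaling_factor
#   decoded = vae.decode(latents / scaling_factor + shift_factor).sample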
class AutoencoderPixelMixer(nn.Module):
    def __init__(self, in_channels=3, downscale_factor=8):
        super().__init__()
        self.mixer = PixelMixer(in_channels, downscale_factor)
        self._dtype = torch.float32
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()
        if downscale_factor == 8:
            # callers go by len(block_out_channels), so simulate it
            self.config.block_out_channels = (1, 1, 1, 1)
            self.config.latent_channels = 192  # 3 * 8**2
        elif downscale_factor == 16:
            # callers go by len(block_out_channels), so simulate it
            self.config.block_out_channels = (1, 1, 1, 1, 1)
            self.config.latent_channels = 768  # 3 * 16**2
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported")
    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value
    # mimic torch's Module.to while tracking dtype/device ourselves
    def to(self, *args, **kwargs):
        # pull out dtype and device whether passed as keywords or positionally
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        if 'device' in kwargs:
            self._device = kwargs['device']
        for arg in args:
            if isinstance(arg, torch.dtype):
                self._dtype = arg
            elif isinstance(arg, (torch.device, str)):
                self._device = torch.device(arg)
        return super().to(*args, **kwargs)
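    # Usage sketch: both keyword and positional forms update the tracked state:
    #   vae.to(torch.float16)                       # updates _dtype
    #   vae.to(device="cuda", dtype=torch.float16)  # updates both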
    def enable_xformers_memory_efficient_attention(self):
        pass
    # @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.mixer.encode(x)
        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (h,)

        class FakeDist:
            # stand-in for DiagonalGaussianDistribution: the "posterior" is
            # deterministic, so sample() and mode() both return the latent as-is
            def __init__(self, x):
                self._sample = x

            def sample(self, generator=None):
                return self._sample

            def mode(self):
                return self._sample

        return AutoencoderKLOutput(latent_dist=FakeDist(h))
    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = self.mixer.decode(z)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z).sample
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)
    # no-op stubs so callers written against AutoencoderKL keep working
    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass
    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
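# Usage sketch: this class can stand in wherever a pipeline expects a diffusers
# AutoencoderKL (hypothetical `images` batch in [0, 1]):
#   vae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8)
#   latents = vae.encode(images).latent_dist.sample()
#   recon = vae.decode(latents).sample  # equals images exactly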
# test it
if __name__ == '__main__':
    import os

    from PIL import Image
    import torchvision.transforms as transforms

    user_path = os.path.expanduser('~')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32

    input_path = os.path.join(user_path, "Pictures/test/test.jpg")
    # write the roundtrip to a separate file so the input is not overwritten
    output_path = os.path.join(user_path, "Pictures/test/test_roundtrip.jpg")

    # force RGB so the channel count matches in_channels=3
    img = Image.open(input_path).convert("RGB")
    img_tensor = transforms.ToTensor()(img)
    img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
    print("input_shape: ", list(img_tensor.shape))

    # note: H and W must be divisible by downscale_factor
    vae = PixelMixer(in_channels=3, downscale_factor=8)
    latent = vae.encode(img_tensor)
    print("latent_shape: ", list(latent.shape))

    out_tensor = vae.decode(latent)
    print("out_shape: ", list(out_tensor.shape))

    # the roundtrip is lossless, so this should print exactly 0.0
    mse_loss = nn.MSELoss()
    mse = mse_loss(img_tensor, out_tensor)
    print("roundtrip_loss: ", mse.item())

    out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
    out_img.save(output_path)