# LoRa_Streamlit/ai-toolkit/toolkit/pixel_shuffle_encoder.py
from typing import Optional, Union

import torch
import torch.nn as nn
import numpy as np

from diffusers import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKLOutput
from diffusers.models.autoencoders.vae import DecoderOutput


class PixelMixer(nn.Module):
    """Lossless stand-in for a VAE: trades spatial resolution for channels via pixel (un)shuffle."""

    def __init__(self, in_channels, downscale_factor):
        super(PixelMixer, self).__init__()
        self.downscale_factor = downscale_factor
        self.in_channels = in_channels

    def forward(self, x):
        latent = self.encode(x)
        out = self.decode(latent)
        return out

    def encode(self, x):
        # (B, C, H, W) -> (B, C * r**2, H / r, W / r), where r = downscale_factor
        return torch.nn.PixelUnshuffle(self.downscale_factor)(x)

    def decode(self, x):
        # (B, C * r**2, H / r, W / r) -> (B, C, H, W)
        return torch.nn.PixelShuffle(self.downscale_factor)(x)
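

# A minimal shape sketch (my own illustration, not part of the original file):
# with in_channels=3 and downscale_factor=8, PixelUnshuffle folds every 8x8
# spatial block into channels, (B, 3, H, W) -> (B, 3 * 8 * 8, H / 8, W / 8) =
# (B, 192, H / 8, W / 8), and PixelShuffle inverts it exactly. Wrapped in a
# function so importing this module stays side-effect free.
def _pixel_mixer_shape_demo():
    mixer = PixelMixer(in_channels=3, downscale_factor=8)
    x = torch.randn(1, 3, 512, 512)
    z = mixer.encode(x)   # -> (1, 192, 64, 64)
    y = mixer.decode(z)   # -> (1, 3, 512, 512)
    assert torch.equal(x, y), "pixel (un)shuffle roundtrip should be lossless"
    return z.shape, y.shape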


# For reference: none of this matters with the llvae, but we need to match the
# AutoencoderKL config interface (latent_channels might matter).
class Config:
    in_channels = 3
    out_channels = 3
    down_block_types = ('1', '1', '1', '1')
    up_block_types = ('1', '1', '1', '1')
    block_out_channels = (1, 1, 1, 1)
    latent_channels = 192  # usually 4
    norm_num_groups = 32
    sample_size = 512
    # scaling_factor = 1
    # shift_factor = 0
    scaling_factor = 1.8
    shift_factor = -0.123
    # Measured VAE latent statistics:
    # - Mean: -0.12306906282901764
    # - Std: 0.556016206741333
    # Normalization parameters:
    # - Shift factor: -0.12306906282901764
    # - Scaling factor: 1.7985087266803625

    def __getitem__(self, x):
        # allow dict-style access, e.g. config["latent_channels"]
        return getattr(self, x)
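
# Where the numbers above come from (my own working, noted for illustration):
#   latent_channels = in_channels * downscale_factor ** 2
#     3 * 8 ** 2  = 192   (downscale_factor=8)
#     3 * 16 ** 2 = 768   (downscale_factor=16)
# scaling_factor / shift_factor follow the usual diffusers convention of
# normalizing latents as z_scaled = (z - shift_factor) * scaling_factor, with
# shift_factor taken from the measured latent mean and scaling_factor from the
# reciprocal of the measured std (1 / 0.556 ≈ 1.80).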


class AutoencoderPixelMixer(nn.Module):
    """Drop-in stand-in for a diffusers AutoencoderKL, backed by PixelMixer."""

    def __init__(self, in_channels=3, downscale_factor=8):
        super().__init__()
        self.mixer = PixelMixer(in_channels, downscale_factor)
        self._dtype = torch.float32
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.config = Config()
        if downscale_factor == 8:
            # we go by len of block_out_channels in code, so simulate it
            self.config.block_out_channels = (1, 1, 1, 1)
            self.config.latent_channels = 192
        elif downscale_factor == 16:
            # we go by len of block_out_channels in code, so simulate it
            self.config.block_out_channels = (1, 1, 1, 1, 1)
            self.config.latent_channels = 768
        else:
            raise ValueError(
                f"downscale_factor {downscale_factor} not supported")

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, value):
        self._device = value

    # mimic torch.nn.Module.to() while tracking dtype / device ourselves
    def to(self, *args, **kwargs):
        # pull out dtype and device if they exist
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        if 'device' in kwargs:
            self._device = kwargs['device']
        return super().to(*args, **kwargs)

    def enable_xformers_memory_efficient_attention(self):
        # no attention layers here; kept for interface compatibility
        pass

    # @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.mixer.encode(x)
        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (h,)

        class FakeDist:
            # stand-in for DiagonalGaussianDistribution; the pixel mixer is
            # deterministic, so sample() and mode() both return the latent as-is
            def __init__(self, x):
                self._sample = x

            def sample(self, generator=None):
                return self._sample

            def mode(self):
                return self._sample

        return AutoencoderKLOutput(latent_dist=FakeDist(h))

    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        dec = self.mixer.decode(z)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    # @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z).sample
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    # the methods below are no-ops: there is nothing to checkpoint, tile, slice,
    # or run attention over, but callers expect the AutoencoderKL API
    def _set_gradient_checkpointing(self, module, value=False):
        pass

    def enable_tiling(self, use_tiling: bool = True):
        pass

    def disable_tiling(self):
        pass

    def enable_slicing(self):
        pass

    def disable_slicing(self):
        pass

    def set_use_memory_efficient_attention_xformers(self, value: bool = True):
        pass

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
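

# Hypothetical usage sketch (my own illustration, not part of the original file)
# showing the AutoencoderKL-style call pattern this wrapper mimics, including
# the shift / scale normalization carried in Config:
def _autoencoder_usage_demo():
    vae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8)
    image = torch.randn(1, 3, 512, 512)
    latents = vae.encode(image).latent_dist.sample()           # (1, 192, 64, 64)
    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor
    # ... a diffusion model would operate on `latents` here ...
    latents = latents / vae.config.scaling_factor + vae.config.shift_factor
    decoded = vae.decode(latents).sample                        # (1, 3, 512, 512)
    return decoded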


# test it
if __name__ == '__main__':
    import os
    from PIL import Image
    import torchvision.transforms as transforms

    user_path = os.path.expanduser('~')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    input_path = os.path.join(user_path, "Pictures/test/test.jpg")
    # write the result to a separate file so the input image is not overwritten
    output_path = os.path.join(user_path, "Pictures/test/test_roundtrip.jpg")

    img = Image.open(input_path)
    img_tensor = transforms.ToTensor()(img)
    img_tensor = img_tensor.unsqueeze(0).to(device=device, dtype=dtype)
    print("input_shape: ", list(img_tensor.shape))

    vae = PixelMixer(in_channels=3, downscale_factor=8)
    latent = vae.encode(img_tensor)
    print("latent_shape: ", list(latent.shape))
    out_tensor = vae.decode(latent)
    print("out_shape: ", list(out_tensor.shape))

    # the pixel shuffle roundtrip is lossless, so this should be ~0
    mse_loss = nn.MSELoss()
    mse = mse_loss(img_tensor, out_tensor)
    print("roundtrip_loss: ", mse.item())

    out_img = transforms.ToPILImage()(out_tensor.squeeze(0))
    out_img.save(output_path)
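
    # Optional extra check (a sketch, not part of the original test): run the
    # same image through the AutoencoderKL-style wrapper and confirm the
    # roundtrip stays exact there too.
    ae = AutoencoderPixelMixer(in_channels=3, downscale_factor=8).to(device=device, dtype=dtype)
    ae_out = ae(img_tensor).sample
    print("wrapper_roundtrip_loss: ", mse_loss(img_tensor, ae_out).item())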