Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import torch | |
import scipy | |
import torch.nn.functional as F | |
from torch import nn | |
from torch.autograd import Variable | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
""" | |
Helper functions for new types of inverse problems | |
""" | |
def fft2(x):
    """Return the 2-D FFT of ``x`` with the DC term shifted to the image center.

    The transform uses ``torch.fft.fft2`` defaults (last two dimensions);
    the shift is applied to those same two dimensions.
    """
    spectrum = torch.fft.fft2(x)
    return torch.fft.fftshift(spectrum, dim=(-2, -1))
def ifft2(x):
    """Return the 2-D inverse FFT of a spectrum whose DC term is centered.

    The DC term is first moved back to the corner over the last two
    dimensions, then ``torch.fft.ifft2`` is applied with its defaults.
    """
    uncentered = torch.fft.ifftshift(x, dim=(-2, -1))
    return torch.fft.ifft2(uncentered)
def fft2_m(x):
    """Centered, orthonormal 2-D FFT over the last two dimensions (multi-coil).

    Real inputs are promoted to ``complex64`` first. Leading dimensions
    (e.g. batch or coil) are left untouched.

    The previous body called ``fft2c_new``, which is neither defined nor
    imported in this module and therefore raised ``NameError``. The transform
    is now written directly with ``torch.fft`` using the centered convention
    (ifftshift -> FFT with ``norm="ortho"`` -> fftshift), which is presumably
    what that helper (fastMRI-style ``fft2c_new``) provided.

    Args:
        x: real or complex tensor whose last two dimensions are spatial.

    Returns:
        Complex tensor with the same shape as ``x``.
    """
    if not torch.is_complex(x):
        x = x.type(torch.complex64)
    spatial_dims = (-2, -1)
    shifted = torch.fft.ifftshift(x, dim=spatial_dims)
    spectrum = torch.fft.fft2(shifted, dim=spatial_dims, norm="ortho")
    return torch.fft.fftshift(spectrum, dim=spatial_dims)
def ifft2_m(x):
    """Centered, orthonormal 2-D inverse FFT over the last two dimensions.

    Real inputs are promoted to ``complex64`` first. Leading dimensions
    (e.g. batch or coil) are left untouched.

    The previous body called ``ifft2c_new``, which is neither defined nor
    imported in this module and therefore raised ``NameError``. The transform
    is now written directly with ``torch.fft`` using the centered convention
    (ifftshift -> inverse FFT with ``norm="ortho"`` -> fftshift), which is
    presumably what that helper (fastMRI-style ``ifft2c_new``) provided.

    Args:
        x: real or complex tensor whose last two dimensions are k-space axes.

    Returns:
        Complex tensor with the same shape as ``x``.
    """
    if not torch.is_complex(x):
        x = x.type(torch.complex64)
    spatial_dims = (-2, -1)
    shifted = torch.fft.ifftshift(x, dim=spatial_dims)
    image = torch.fft.ifft2(shifted, dim=spatial_dims, norm="ortho")
    return torch.fft.fftshift(image, dim=spatial_dims)
def clear(x):
    """Convert a tensor to a NumPy array and rescale it with ``normalize_np``.

    The tensor is detached, moved to the CPU and squeezed before conversion.
    """
    array = x.detach().cpu().squeeze().numpy()
    return normalize_np(array)
def resize_and_crop(image, imsize=512):
    """Resize so the shorter side equals ``imsize``, then crop the center square.

    Args:
        image: object exposing ``size`` as ``(width, height)`` plus ``resize``
            and ``crop`` methods (presumably a PIL image — confirm with callers).
        imsize: side length of the requested square crop.

    Returns:
        Whatever ``image.resize(...).crop(...)`` returns.
    """
    width, height = image.size
    # The shorter side is pinned to ``imsize``; the longer side keeps the
    # aspect ratio and is truncated to an int.
    if width < height:
        target = (imsize, int((imsize / width) * height))
    else:
        target = (int((imsize / height) * width), imsize)
    resized = image.resize(target)

    new_width, new_height = target
    crop_box = (
        (new_width - imsize) / 2,   # left
        (new_height - imsize) / 2,  # top
        (new_width + imsize) / 2,   # right
        (new_height + imsize) / 2,  # bottom
    )
    return resized.crop(crop_box)
def clear_color(x, normalize=True):
    """Convert a tensor to a NumPy image, channels-last when it is 3-D.

    Complex inputs are replaced by their magnitude. With ``normalize=True``
    the array is rescaled by ``normalize_np``; otherwise the tensor is mapped
    through ``x / 2 + 0.5`` and clamped to ``[0, 1]`` before conversion.

    Args:
        x: tensor to convert; it is detached, moved to CPU and squeezed.
        normalize: choose min/max rescaling (True) or the fixed affine map.

    Returns:
        NumPy array; a 3-D result is transposed with axes ``(1, 2, 0)``.
    """
    if torch.is_complex(x):
        x = torch.abs(x)
    if not normalize:
        x = (x / 2 + 0.5).clamp(0, 1)

    array = x.detach().cpu().squeeze().numpy()
    if array.ndim == 3:
        array = np.transpose(array, (1, 2, 0))
    return normalize_np(array) if normalize else array
def normalize_np(img):
    """Normalize ``img`` from an arbitrary range to ``[0, 1]``.

    Unlike the previous in-place version, the input array is left untouched
    (so arrays that share memory with a tensor are not silently modified),
    integer inputs are accepted (promoted to ``float64``), and a constant
    image maps to all zeros instead of dividing by zero.

    Args:
        img: array-like of numbers; must be non-empty.

    Returns:
        A new array of the same shape. Floating inputs keep their dtype.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.inexact):
        # In-place true division is undefined for integer arrays.
        img = img.astype(np.float64)
    shifted = img - np.min(img)
    peak = np.max(shifted)
    if peak == 0:
        # Constant image: every pixel already sits at the minimum.
        return np.zeros_like(shifted)
    return shifted / peak
def prepare_im(load_dir, image_size, device):
    """Load an image file as a ``(1, 3, image_size, image_size)`` tensor in [-1, 1].

    The first three channels are read with ``plt.imread``, cast to float32,
    rescaled with ``normalize_np``, moved to ``device`` and mapped through
    ``x * 2 - 1``.

    Args:
        load_dir: path handed to ``plt.imread``.
        image_size: side length used in the final ``view``; the loaded image
            must therefore already be ``image_size`` x ``image_size``.
        device: target torch device.
    """
    rgb = plt.imread(load_dir)[:, :, :3].astype(np.float32)
    ref_img = torch.from_numpy(normalize_np(rgb)).to(device)
    # Channels-last -> channels-first, then add the batch dimension.
    ref_img = ref_img.permute(2, 0, 1).view(1, 3, image_size, image_size)
    return ref_img * 2 - 1
def fold_unfold(img_t, kernel, stride):
    """Split a ``[B, C, H, W]`` tensor into patches and fold them back.

    Intermediate shapes are printed for debugging.

    Args:
        img_t: 4-D tensor.
        kernel: square patch size.
        stride: step between neighbouring patches.

    Returns:
        ``(patches, output)``: ``patches`` has the
        ``[B, C * kernel * kernel, L]`` layout expected by ``F.fold`` and
        ``output`` is the folded image divided by the per-pixel overlap count.
    """
    img_shape = img_t.shape
    B, C, H, W = img_shape
    print("\n----- input shape: ", img_shape)

    # Unfold width first, then height; the permute swaps the two trailing
    # patch axes before flattening.
    windows = img_t.unfold(3, kernel, stride).unfold(2, kernel, stride)
    patches = windows.permute(0, 1, 2, 3, 5, 4)
    print("\n----- patches shape:", patches.shape)

    # Reshape to match the F.fold input layout.
    patches = patches.contiguous().view(B, C, -1, kernel * kernel)
    print("\n", patches.shape)  # [B, C, nb_patches_all, kernel_size*kernel_size]
    patches = patches.permute(0, 1, 3, 2)
    print("\n", patches.shape)  # [B, C, kernel_size*kernel_size, nb_patches_all]
    patches = patches.contiguous().view(B, C * kernel * kernel, -1)
    print("\n", patches.shape)  # [B, C*prod(kernel_size), L] as expected by Fold

    fold_args = dict(output_size=(H, W), kernel_size=kernel, stride=stride)
    summed = F.fold(patches, **fold_args)
    # Folding a tensor of ones counts how many patches cover each pixel.
    coverage = F.fold(torch.ones_like(patches), **fold_args)
    return patches, summed / coverage
def reshape_patch(x, crop_size=128, dim_size=3):
    """Turn unfolded columns into a stack of 3-channel square patches.

    Args:
        x: tensor whose first and third axes are swapped and then squeezed,
            e.g. ``[1, 3 * crop_size**2, dim_size**2]``.
        crop_size: patch side length.
        dim_size: number of patches per image side.

    Returns:
        Tensor of shape ``[dim_size**2, 3, crop_size, crop_size]``.
    """
    per_patch = x.transpose(0, 2).squeeze()  # e.g. [9, 3*(128**2)]
    return per_patch.view(dim_size ** 2, 3, crop_size, crop_size)
def reshape_patch_back(x, crop_size=128, dim_size=3):
    """Flatten a stack of 3-channel patches into foldable columns.

    Args:
        x: tensor viewable as ``[dim_size**2, 3 * crop_size**2]``.
        crop_size: patch side length.
        dim_size: number of patches per image side.

    Returns:
        Tensor of shape ``[1, 3 * crop_size**2, dim_size**2]``.
    """
    flat = x.view(dim_size ** 2, 3 * (crop_size ** 2))
    return flat.unsqueeze(dim=-1).transpose(0, 2)
class Unfolder:
    """Callable that cuts an image batch into overlapping square patches."""

    def __init__(self, img_size=256, crop_size=128, stride=64):
        """Store the geometry and build the ``nn.Unfold`` operator.

        ``dim_size`` is the number of patches along one image side.
        """
        self.img_size, self.crop_size, self.stride = img_size, crop_size, stride
        self.unfold = nn.Unfold(crop_size, stride=stride)
        self.dim_size = (img_size - crop_size) // stride + 1

    def __call__(self, x):
        """Unfold ``x`` and reshape the columns into 2-D patches."""
        return reshape_patch(
            self.unfold(x),
            crop_size=self.crop_size,
            dim_size=self.dim_size,
        )
def center_crop(img, new_width=None, new_height=None):
    """Crop the central ``new_height`` x ``new_width`` window of ``img``.

    Args:
        img: array indexed as ``[row, column, ...]``.
        new_width: crop width; defaults to the shorter image side.
        new_height: crop height; defaults to the shorter image side.

    Returns:
        Slice of ``img`` (trailing axes, if any, are kept whole).
    """
    height, width = img.shape[0], img.shape[1]
    shorter = min(width, height)
    if new_width is None:
        new_width = shorter
    if new_height is None:
        new_height = shorter

    # When the surplus is odd, the extra pixel is trimmed from the
    # left/top edge (ceil there, floor on the opposite edge).
    surplus_w = (width - new_width) / 2
    surplus_h = (height - new_height) / 2
    left = int(np.ceil(surplus_w))
    right = width - int(np.floor(surplus_w))
    top = int(np.ceil(surplus_h))
    bottom = height - int(np.floor(surplus_h))

    return img[top:bottom, left:right]
class Folder:
    """Callable that stitches a stack of square patches back into an image."""

    def __init__(self, img_size=256, crop_size=128, stride=64):
        """Store the geometry and build the ``nn.Fold`` operator.

        ``dim_size`` is the number of patches along one image side.
        """
        self.img_size, self.crop_size, self.stride = img_size, crop_size, stride
        self.fold = nn.Fold(img_size, crop_size, stride=stride)
        self.dim_size = (img_size - crop_size) // stride + 1

    def __call__(self, patch2D):
        """Flatten ``patch2D`` into columns and fold them (overlaps are summed)."""
        columns = reshape_patch_back(
            patch2D,
            crop_size=self.crop_size,
            dim_size=self.dim_size,
        )
        return self.fold(columns)
def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)):
    """Generate a random square mask for inpainting.

    Args:
        img: ``[B, C, H, W]`` tensor; only its shape and device are used.
        mask_shape: ``(height, width)`` of the zeroed box.
        image_size: extent used to bound the random box position.
        margin: ``(vertical, horizontal)`` minimum distance from the border.

    Returns:
        ``(mask, top, bottom, left, right)`` where ``mask`` is one everywhere
        except inside the box ``[top:bottom, left:right]``.
    """
    B, C, H, W = img.shape
    box_h, box_w = mask_shape
    margin_h, margin_w = margin

    # Sample the top-left corner; the upper bounds are exclusive.
    top = np.random.randint(margin_h, image_size - margin_h - box_h)
    left = np.random.randint(margin_w, image_size - margin_w - box_w)
    bottom, right = top + box_h, left + box_w

    mask = torch.ones([B, C, H, W], device=img.device)
    mask[..., top:bottom, left:right] = 0
    return mask, top, bottom, left, right
class mask_generator:
    """Callable that draws a random inpainting mask shaped like its input."""

    def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None,
                 image_size=256, margin=(16, 16)):
        """
        (mask_type): one of 'box', 'random', 'both', 'extreme'.
            'both' is accepted here but has no mask implementation yet.
        (mask_len_range): given in (min, max) tuple.
            Specifies the range of box size in each dimension
        (mask_prob_range): for the case of random masking,
            specify the probability of individual pixels being masked
        (image_size): side length of the (square) images being masked
        (margin): minimum distance of a box from the image border
        """
        assert mask_type in ['box', 'random', 'both', 'extreme']
        self.mask_type = mask_type
        self.mask_len_range = mask_len_range
        self.mask_prob_range = mask_prob_range
        self.image_size = image_size
        self.margin = margin

    def _retrieve_box(self, img):
        """Return ``random_sq_bbox`` output for a box of random height/width."""
        low, high = (int(v) for v in self.mask_len_range)
        mask_h = np.random.randint(low, high)
        mask_w = np.random.randint(low, high)
        return random_sq_bbox(img,
                              mask_shape=(mask_h, mask_w),
                              image_size=self.image_size,
                              margin=self.margin)

    def _retrieve_random(self, img):
        """Return a mask with a random fraction of pixels set to zero.

        The same pixel pattern is shared by every batch item and channel.
        """
        total = self.image_size ** 2
        # random pixel sampling
        low, high = self.mask_prob_range
        prob = np.random.uniform(low, high)
        mask_vec = torch.ones([1, total])
        samples = np.random.choice(total, int(total * prob), replace=False)
        mask_vec[:, samples] = 0
        plane = mask_vec.view(1, self.image_size, self.image_size)
        mask = torch.ones_like(img, device=img.device)
        # Broadcast the single plane over batch and channel axes; the old
        # ``repeat(3, 1, 1)`` only worked for 3-channel inputs.
        mask[:, ...] = plane
        return mask

    def __call__(self, img):
        """Draw a mask for ``img`` according to ``mask_type``.

        Raises:
            NotImplementedError: for a ``mask_type`` (currently 'both') that
                passes the constructor check but has no implementation;
                previously this case silently returned ``None``.
        """
        if self.mask_type == 'random':
            return self._retrieve_random(img)
        if self.mask_type == 'box':
            mask, *_ = self._retrieve_box(img)
            return mask
        if self.mask_type == 'extreme':
            # Inverted box mask: only the box region is kept.
            mask, *_ = self._retrieve_box(img)
            return 1. - mask
        raise NotImplementedError(
            f"mask_type {self.mask_type!r} has no mask implementation"
        )
def unnormalize(img, s=0.95):
    """Divide ``img`` by the ``s``-quantile of its absolute values.

    NOTE(review): despite the name, this divides by the quantile while
    ``normalize`` multiplies — confirm the naming is intentional.
    """
    return img / torch.quantile(img.abs(), s)
def normalize(img, s=0.95):
    """Multiply ``img`` by the ``s``-quantile of its absolute values.

    NOTE(review): despite the name, this multiplies by the quantile while
    ``unnormalize`` divides — confirm the naming is intentional.
    """
    return img * torch.quantile(img.abs(), s)
def dynamic_thresholding(img, s=0.95):
    """Rescale ``img`` with ``normalize`` and clamp the result to ``[-1, 1]``."""
    scaled = normalize(img, s=s)
    return torch.clamp(scaled, -1., 1.)
def get_gaussian_kernel(kernel_size=31, std=0.5):
    """Return a ``kernel_size`` x ``kernel_size`` float32 Gaussian kernel.

    The kernel is the response of ``gaussian_filter`` to a unit impulse
    placed at the center pixel.

    Args:
        kernel_size: side length of the square kernel.
        std: Gaussian standard deviation passed as ``sigma``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.ndimage`` is loaded on every SciPy version.
    from scipy import ndimage

    center = kernel_size // 2
    impulse = np.zeros([kernel_size, kernel_size])
    impulse[center, center] = 1
    return ndimage.gaussian_filter(impulse, sigma=std).astype(np.float32)
def init_kernel_torch(kernel, device="cuda:0"):
    """Turn a 2-D NumPy kernel into a gradient-tracking ``[1, 3, h, w]`` tensor.

    Uses ``requires_grad_`` instead of the deprecated
    ``torch.autograd.Variable`` wrapper; the result is the same.

    Args:
        kernel: 2-D NumPy array.
        device: target torch device.

    Returns:
        The kernel repeated over three channels; gradients flow back to the
        underlying single-channel tensor.
    """
    h, w = kernel.shape
    base = torch.from_numpy(kernel).to(device).requires_grad_(True)
    return base.view(1, 1, h, w).repeat(1, 3, 1, 1)
class Blurkernel(nn.Module):
    """Depthwise 3-channel blur: reflection padding followed by a fixed conv."""

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        """
        Args:
            blur_type: "gaussian" or "motion"; any other value leaves the
                convolution weights at their default initialisation.
            kernel_size: side length of the square kernel.
            std: Gaussian sigma, or the ``intensity`` of the motion kernel.
            device: device used by ``update_weights`` for NumPy kernels.
        """
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
        )
        self.weights_init()

    def forward(self, x):
        """Blur ``x``; padding keeps the spatial size for odd kernel sizes."""
        return self.seq(x)

    def weights_init(self):
        """Build the kernel for ``blur_type`` and load it into the conv weights."""
        if self.blur_type == "gaussian":
            # Import the submodule explicitly: a bare ``import scipy`` does
            # not guarantee ``scipy.ndimage`` is loaded on every SciPy version.
            from scipy import ndimage

            center = self.kernel_size // 2
            impulse = np.zeros((self.kernel_size, self.kernel_size))
            impulse[center, center] = 1
            k = ndimage.gaussian_filter(impulse, sigma=self.std)
        elif self.blur_type == "motion":
            # NOTE(review): ``Kernel`` is not defined or imported in this
            # file, so this branch raises NameError unless it is provided
            # elsewhere — confirm where it should come from.
            k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
        else:
            return
        k = torch.from_numpy(k)
        self.k = k
        self._copy_to_weights(k)

    def update_weights(self, k):
        """Overwrite the conv weights with ``k`` (tensor or NumPy array).

        The kernel returned by ``get_kernel`` is not changed by this call.
        """
        if not torch.is_tensor(k):
            k = torch.from_numpy(k).to(self.device)
        self._copy_to_weights(k)

    def get_kernel(self):
        """Return the kernel tensor created by ``weights_init``."""
        return self.k

    def _copy_to_weights(self, k):
        """Copy ``k`` into every parameter of the module."""
        for _, param in self.named_parameters():
            param.data.copy_(k)
class exact_posterior():
    """Closed-form Gaussian terms for a measurement and a diffusion step."""

    def __init__(self, betas, sigma_0, label_dim, input_dim):
        """
        Args:
            betas: sequence indexed by timestep ``t``.
            sigma_0: measurement noise standard deviation.
            label_dim: exponent of ``2*pi`` in ``pxt_given_x0``.
            input_dim: exponent of ``2*pi`` in ``py_given_x0``.
        """
        self.betas = betas
        self.sigma_0 = sigma_0
        self.label_dim = label_dim
        self.input_dim = input_dim

    def py_given_x0(self, x0, y, A, verbose=False):
        """Evaluate ``norm_const * exp(-||y - A(x0)||^2 / (2 * sigma_0^2))``.

        With ``verbose=True`` also return the constant and the exponent.

        NOTE(review): the constant is ``1 / ((2*pi)**input_dim * sigma_0**2)``,
        not the textbook Gaussian normaliser — confirm this is intended.
        """
        norm_const = 1 / ((2 * np.pi) ** self.input_dim * self.sigma_0 ** 2)
        exp_in = -1 / (2 * self.sigma_0 ** 2) * torch.linalg.norm(y - A(x0)) ** 2
        density = norm_const * torch.exp(exp_in)
        if verbose:
            return density, norm_const, exp_in
        return density

    def pxt_given_x0(self, x0, xt, t, verbose=False):
        """Evaluate ``norm_const * exp(-||xt - sqrt(1 - beta_t) * x0||^2 / (2 * beta_t))``.

        With ``verbose=True`` also return the constant and the exponent.
        """
        beta_t = self.betas[t]
        norm_const = 1 / ((2 * np.pi) ** self.label_dim * beta_t)
        exp_in = -1 / (2 * beta_t) * torch.linalg.norm(xt - np.sqrt(1 - beta_t) * x0) ** 2
        density = norm_const * torch.exp(exp_in)
        if verbose:
            return density, norm_const, exp_in
        return density

    def prod_logsumexp(self, x0, xt, y, A, t):
        """Return ``logsumexp`` over dim 0 of ``p(y|x0) * p(xt|x0)``.

        The summand is the product of the two densities defined above:
        their constants multiply and their exponents add. The previous code
        used ``exp(-pxtx0_ei - pxtx0_ei)``, which ignored the measurement
        exponent and flipped the sign of the diffusion exponent.
        """
        _, pyx0_nc, pyx0_ei = self.py_given_x0(x0, y, A, verbose=True)
        _, pxtx0_nc, pxtx0_ei = self.pxt_given_x0(x0, xt, t, verbose=True)
        summand = (pyx0_nc * pxtx0_nc) * torch.exp(pyx0_ei + pxtx0_ei)
        return torch.logsumexp(summand, dim=0)
def map2tensor(gray_map):
    """Move gray maps to GPU, no normalization is done.

    Two leading singleton axes are added before the tensor is sent to CUDA.
    """
    tensor = torch.FloatTensor(gray_map)
    return tensor[None, None].cuda()
def create_penalty_mask(k_size, penalty_scale):
    """Generate a mask of weights penalizing values close to the boundaries.

    The mask is one minus a peak-normalised Gaussian (sigma ``k_size``),
    with a central square set to zero, scaled by ``penalty_scale``.

    NOTE(review): when ``margin`` evaluates to 0 the slice ``0:-0`` is empty
    and nothing is zeroed — confirm small ``k_size`` values are not used.
    """
    center_size = k_size // 2 + k_size % 2
    margin = (k_size - center_size) // 2 - 1

    gaussian = create_gaussian(size=k_size, sigma1=k_size, is_tensor=False)
    mask = 1 - gaussian / np.max(gaussian)

    inner = slice(margin, -margin)
    mask[inner, inner] = 0
    return penalty_scale * mask
def create_gaussian(size, sigma1, sigma2=-1, is_tensor=False):
    """Return a ``size`` x ``size`` Gaussian as the outer product of two 1-D profiles.

    Args:
        size: number of samples per axis.
        sigma1: standard deviation of the first (row) profile.
        sigma2: standard deviation of the second profile; ``-1`` reuses the
            first profile.
        is_tensor: return a CUDA ``FloatTensor`` instead of a NumPy array.
    """
    offsets = range(-size // 2 + 1, size // 2 + 1)

    def profile(sigma):
        # Normal pdf sampled at the integer offsets.
        return [np.exp(-z ** 2 / (2 * sigma ** 2)) / np.sqrt(2 * np.pi * sigma ** 2)
                for z in offsets]

    func1 = profile(sigma1)
    func2 = func1 if sigma2 == -1 else profile(sigma2)
    kernel = np.outer(func1, func2)
    if is_tensor:
        return torch.FloatTensor(kernel).cuda()
    return kernel
def total_variation_loss(img, weight):
    """Return ``weight`` times the mean squared neighbour difference of ``img``.

    Differences are taken along axis 2 and axis 3 of a 4-D tensor; the two
    means are added before weighting.
    """
    diff_h = img[:, :, 1:, :] - img[:, :, :-1, :]
    diff_w = img[:, :, :, 1:] - img[:, :, :, :-1]
    return weight * (diff_h.pow(2).mean() + diff_w.pow(2).mean())
if __name__ == '__main__':
    # Manual smoke test: load one image, draw a mask for it and display it.
    # numpy, torch and matplotlib are already imported at module level.
    device = 'cuda:0'
    load_path = '/media/harry/tomo/FFHQ/256/test/00000.png'
    img = torch.tensor(plt.imread(load_path)[:, :, :3])  # rgb
    img = torch.permute(img, (2, 0, 1)).view(1, 3, 256, 256).to(device)
    mask_len_range = (32, 128)
    mask_prob_range = (0.3, 0.7)
    image_size = 256
    # mask
    # ``mask_type`` is a required argument; omitting it raised TypeError.
    mask_gen = mask_generator(
        mask_type='box',
        mask_len_range=mask_len_range,
        mask_prob_range=mask_prob_range,
        image_size=image_size
    )
    mask = mask_gen(img)
    mask = np.transpose(mask.squeeze().cpu().detach().numpy(), (1, 2, 0))
    plt.imshow(mask)
    plt.show()