Spaces:
Runtime error
Runtime error
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
# | |
# This work is licensed under the Creative Commons Attribution-NonCommercial | |
# 4.0 International License. To view a copy of this license, visit | |
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to | |
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. | |
import numpy as np | |
import scipy.ndimage | |
#---------------------------------------------------------------------------- | |
def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image):
    """Sample random square neighborhoods (patches) from every image of a minibatch.

    Args:
        minibatch: 4-D array laid out as (minibatch, channel, height, width).
            Any channel count is accepted.
        nhood_size: Patch size. The patch spans ``nhood_size // 2`` pixels on
            each side of a random center, so an even value yields patches of
            ``nhood_size + 1`` pixels per side.
        nhoods_per_image: Number of patches drawn from each image.

    Returns:
        Array of shape (minibatch * nhoods_per_image, channel, P, P) with
        P = 2 * (nhood_size // 2) + 1 and the dtype of ``minibatch``. Patches
        ``i * nhoods_per_image`` .. ``(i + 1) * nhoods_per_image - 1`` come from
        image ``i``. Note that output axis 2 walks image columns (x) and axis 3
        walks image rows (y), i.e. each patch is transposed relative to the
        image; reals and fakes are sampled identically, so comparisons between
        them are unaffected.

    Raises:
        ValueError: if height or width is too small to contain a full patch.
    """
    S = minibatch.shape
    assert len(S) == 4  # (minibatch, channel, height, width)
    num_images, num_channels, height, width = S
    N = nhoods_per_image * num_images
    H = nhood_size // 2
    if height <= 2 * H or width <= 2 * H:
        raise ValueError('nhood_size=%d needs images of at least %dx%d pixels, got %dx%d'
                         % (nhood_size, 2 * H + 1, 2 * H + 1, height, width))
    # Broadcastable index grids; x varies along output axis 2, y along axis 3.
    nhood, chan, x, y = np.ogrid[0:N, 0:num_channels, -H:H+1, -H:H+1]
    img = nhood // nhoods_per_image
    # Random patch centers; randint's upper bound is exclusive, so the patch
    # always stays inside the image.
    x = x + np.random.randint(H, width - H, size=(N, 1, 1, 1))
    y = y + np.random.randint(H, height - H, size=(N, 1, 1, 1))
    # Flat C-order index of minibatch[img, chan, y, x].
    idx = ((img * num_channels + chan) * height + y) * width + x
    return minibatch.flat[idx]
#---------------------------------------------------------------------------- | |
def finalize_descriptors(desc):
    """Concatenate, standardize per channel, and flatten neighborhood descriptors.

    Args:
        desc: A 4-D array, or a list/tuple of 4-D arrays concatenated along
            axis 0, laid out as (neighborhood, channel, patch, patch). The
            caller's arrays are never modified.

    Returns:
        2-D array (neighborhood, channel * patch * patch). Each channel has zero
        mean and unit standard deviation over all neighborhoods and patch
        positions. A channel that is constant is left at zero instead of
        becoming NaN. Floating inputs keep their dtype; other dtypes are
        promoted to float64.
    """
    from_list = isinstance(desc, (list, tuple))
    desc = np.concatenate(desc, axis=0) if from_list else np.asarray(desc)
    assert desc.ndim == 4  # (neighborhood, channel, height, width)
    if not np.issubdtype(desc.dtype, np.floating):
        # In-place float arithmetic below would fail on an integer array.
        desc = desc.astype(np.float64)
    elif not from_list:
        # np.concatenate already produced a fresh array; a passed-in array must be copied.
        desc = desc.copy()
    desc -= np.mean(desc, axis=(0, 2, 3), keepdims=True)
    std = np.std(desc, axis=(0, 2, 3), keepdims=True)
    std[std == 0] = 1  # constant channel: avoid 0/0
    desc /= std
    return desc.reshape(desc.shape[0], -1)
#---------------------------------------------------------------------------- | |
def sliced_wasserstein(A, B, dir_repeats, dirs_per_repeat):
    """Estimate the sliced Wasserstein distance between two descriptor sets.

    Both sets are projected onto random unit directions. Per direction the
    projections are sorted and compared elementwise, giving the 1-D
    Wasserstein-1 distance between the two empirical distributions.

    Args:
        A, B: 2-D arrays of identical shape (neighborhood, descriptor_component).
        dir_repeats: Number of independent batches of random directions.
        dirs_per_repeat: Number of random directions in each batch.

    Returns:
        Mean absolute difference of sorted projections, averaged over
        neighborhoods and directions within a batch, then over the batches.
    """
    assert A.ndim == 2 and A.shape == B.shape  # (neighborhood, descriptor_component)
    num_components = A.shape[1]
    batch_means = []
    for _ in range(dir_repeats):
        # Gaussian vectors, one per column, scaled to unit length.
        directions = np.random.randn(num_components, dirs_per_repeat)
        lengths = np.sqrt(np.sum(np.square(directions), axis=0, keepdims=True))
        directions = (directions / lengths).astype(np.float32)
        # Sorting along the neighborhood axis pairs up matching quantiles.
        sorted_a = np.sort(A @ directions, axis=0)
        sorted_b = np.sort(B @ directions, axis=0)
        batch_means.append(np.mean(np.abs(sorted_a - sorted_b)))
    return np.mean(batch_means)
#---------------------------------------------------------------------------- | |
def downscale_minibatch(minibatch, lod):
    """Downscale a minibatch by 2x2 average pooling, ``lod`` times.

    Args:
        minibatch: 4-D array (minibatch, channel, height, width).
        lod: Number of halvings. 0 returns ``minibatch`` itself, unconverted.

    Returns:
        uint8 array of pooled values, rounded with ``np.round`` (ties to even)
        and clipped to [0, 255]. An odd trailing row or column is dropped
        before each halving; previously it raised a broadcasting error.
    """
    if lod == 0:
        return minibatch
    t = minibatch.astype(np.float32)
    for _ in range(lod):
        # Crop to even height/width so the four strided views have equal shapes.
        even_h = t.shape[2] - t.shape[2] % 2
        even_w = t.shape[3] - t.shape[3] % 2
        t = t[:, :, :even_h, :even_w]
        t = (t[:, :, 0::2, 0::2] + t[:, :, 0::2, 1::2] + t[:, :, 1::2, 0::2] + t[:, :, 1::2, 1::2]) * 0.25
    return np.round(t).clip(0, 255).astype(np.uint8)
#---------------------------------------------------------------------------- | |
# 5x5 binomial approximation of a Gaussian: the outer product of the row
# [1, 4, 6, 4, 1] with itself, divided by 256 so the weights sum to 1 (float32).
gaussian_filter = np.outer(np.float32([1, 4, 6, 4, 1]), np.float32([1, 4, 6, 4, 1])) / 256.0
def pyr_down(minibatch):  # matches cv2.pyrDown()
    """Blur each (height, width) plane with ``gaussian_filter``, then keep every other pixel.

    Borders use scipy's ``mode='mirror'`` (reflection about the edge pixel).
    The spatial output size is ceil(size / 2). scipy allocates the output with
    the input's dtype, so callers presumably pass floats. NOTE(review): integer
    inputs would lose precision.
    """
    assert minibatch.ndim == 4
    # A (1, 1, 5, 5) kernel never mixes values across the batch or channel axes.
    kernel = gaussian_filter[None, None]
    blurred = scipy.ndimage.convolve(minibatch, kernel, mode='mirror')
    return blurred[..., ::2, ::2]
def pyr_up(minibatch):  # matches cv2.pyrUp()
    """Double the height and width: zero-insert, then blur with 4x ``gaussian_filter``.

    Input samples land on even (row, column) positions, and all others start
    at zero. The factor 4 compensates for three of every four samples being
    zero. Borders use scipy's ``mode='mirror'``, and the output keeps the
    input's dtype.
    """
    assert minibatch.ndim == 4
    batch, channels, height, width = minibatch.shape
    upsampled = np.zeros((batch, channels, 2 * height, 2 * width), dtype=minibatch.dtype)
    upsampled[..., ::2, ::2] = minibatch
    kernel = gaussian_filter[None, None] * 4.0
    return scipy.ndimage.convolve(upsampled, kernel, mode='mirror')
def generate_laplacian_pyramid(minibatch, num_levels):
    """Decompose a minibatch into a Laplacian pyramid of ``num_levels`` levels.

    Level 0 is full resolution, and each further level comes from ``pyr_down``
    of the previous Gaussian level. Every level except the last holds the
    detail lost by downsampling (Gaussian level minus ``pyr_up`` of the next
    one, cropped to matching size). The last level is the coarsest Gaussian
    level itself.

    Args:
        minibatch: 4-D array (minibatch, channel, height, width). Never modified.
        num_levels: Number of levels. Values below 2 return only a float32
            copy of the input.

    Returns:
        List of float32 arrays, finest first.
    """
    # Always copy. np.float32(x) does not copy an array that is already float32,
    # and the in-place subtraction below would then overwrite the caller's data.
    pyramid = [np.array(minibatch, dtype=np.float32)]
    for _ in range(1, num_levels):
        finer = pyramid[-1]
        coarser = pyr_down(finer)
        # pyr_down rounds odd sizes up, so pyr_up can overshoot by one
        # row/column; crop it back to the finer level's size.
        upsampled = pyr_up(coarser)[:, :, :finer.shape[2], :finer.shape[3]]
        finer -= upsampled
        pyramid.append(coarser)
    return pyramid
def reconstruct_laplacian_pyramid(pyramid):
    """Rebuild the full-resolution minibatch from a Laplacian pyramid (finest level first).

    The function starts from the coarsest level and repeatedly upsamples it
    with ``pyr_up``, adding the next-finer detail level. Each upsampled result
    is cropped to that level's height and width, so levels with odd sizes
    (where ``pyr_up`` overshoots by one) are supported.
    """
    minibatch = pyramid[-1]
    for level in pyramid[-2::-1]:
        upsampled = pyr_up(minibatch)[:, :, :level.shape[2], :level.shape[3]]
        minibatch = upsampled + level
    return minibatch
#---------------------------------------------------------------------------- | |
class API:
    """Sliced Wasserstein distance (SWD) metric computed per Laplacian-pyramid level.

    Each pass is ``begin(mode)``, then one or more ``feed(mode, minibatch)``,
    then ``end(mode)``. Descriptors finalized in a 'warmup' or 'reals' pass
    are stored in ``self.desc_real``. Every ``end`` compares the current pass
    against that stored reference.
    """

    def __init__(self, num_images, image_shape, image_dtype, minibatch_size):
        # num_images, image_dtype and minibatch_size are accepted but not used here.
        self.nhood_size = 7          # neighborhood (patch) size passed to the sampler
        self.nhoods_per_image = 128  # patches sampled per image and pyramid level
        self.dir_repeats = 4         # batches of random projection directions
        self.dirs_per_repeat = 128   # directions per batch
        # One entry per pyramid level: image_shape[1], halved while still >= 16.
        # NOTE(review): presumably image_shape is (channel, height, width) with
        # square images — confirm.
        levels = []
        size = image_shape[1]
        while size >= 16:
            levels.append(size)
            size //= 2
        self.resolutions = levels

    def get_metric_names(self):
        """Return one 'SWDx1e3_<resolution>' name per level, then 'SWDx1e3_avg'."""
        names = ['SWDx1e3_%d' % size for size in self.resolutions]
        names.append('SWDx1e3_avg')
        return names

    def get_metric_formatting(self):
        """Return one printf-style format string per metric name."""
        return ['%-13.4f' for _ in self.get_metric_names()]

    def begin(self, mode):
        """Start a pass with one empty descriptor bucket per pyramid level."""
        assert mode in ['warmup', 'reals', 'fakes']
        self.descriptors = [[] for _ in range(len(self.resolutions))]

    def feed(self, mode, minibatch):
        """Sample neighborhood descriptors from every pyramid level of ``minibatch``."""
        pyramid = generate_laplacian_pyramid(minibatch, len(self.resolutions))
        for lod in range(len(pyramid)):
            desc = get_descriptors_for_minibatch(pyramid[lod], self.nhood_size, self.nhoods_per_image)
            self.descriptors[lod].append(desc)

    def end(self, mode):
        """Finish the pass and return per-level SWD x 1e3 followed by their mean."""
        current = [finalize_descriptors(d) for d in self.descriptors]
        del self.descriptors
        if mode in ('warmup', 'reals'):
            # This pass becomes the reference for later 'fakes' passes.
            self.desc_real = current
        scores = []
        for real, fake in zip(self.desc_real, current):
            swd = sliced_wasserstein(real, fake, self.dir_repeats, self.dirs_per_repeat)
            scores.append(swd * 1e3)  # report in units of 10^-3
        return scores + [np.mean(scores)]
#---------------------------------------------------------------------------- | |