# -*- coding: utf-8 -*-
# Description:
import warnings
from math import ceil
from fractions import Fraction

from . import interp_methods


class NoneClass:
    pass


try:
    import torch
    from torch import nn
    nnModuleWrapped = nn.Module
except ImportError:
    warnings.warn('No PyTorch found, will work only with Numpy')
    torch = None
    nnModuleWrapped = NoneClass

try:
    import numpy
except ImportError:
    warnings.warn('No Numpy found, will work only with PyTorch')
    numpy = None

if numpy is None and torch is None:
    raise ImportError("Must have either Numpy or PyTorch, but neither was "
                      "found")

def resize(input, scale_factors=None, out_shape=None,
           interp_method=interp_methods.cubic, support_sz=None,
           antialiasing=True, by_convs=False, scale_tolerance=None,
           max_numerator=10, pad_mode='replicate'):
    # get properties of the input tensor
    in_shape, n_dims = input.shape, input.ndim

    # fw stands for framework, either numpy or torch,
    # determined by the input type
    fw = numpy if type(input) is numpy.ndarray else torch
    eps = fw.finfo(fw.float32).eps
    device = input.device if fw is torch else None

    # set missing scale factors or output shape, one according to the other,
    # and raise an error if both are missing. this is also where all the
    # default policies take place. also handles the by_convs attribute
    # carefully.
    scale_factors, out_shape, by_convs = set_scale_and_out_sz(in_shape,
                                                              out_shape,
                                                              scale_factors,
                                                              by_convs,
                                                              scale_tolerance,
                                                              max_numerator,
                                                              eps, fw)

    # sort indices of dimensions according to scale of each dimension.
    # since we are going dim by dim this is efficient
    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim], by_convs[dim],
                                        in_shape[dim], out_shape[dim])
                                       for dim in sorted(range(n_dims),
                                       key=lambda ind: scale_factors[ind])
                                       if scale_factors[dim] != 1.]

    # unless support size is specified by the user, it is an attribute
    # of the interpolation method
    if support_sz is None:
        support_sz = interp_method.support_sz

    # output begins identical to input and changes with each iteration
    output = input

    # iterate over dims
    for (dim, scale_factor, dim_by_convs, in_sz, out_sz
         ) in sorted_filtered_dims_and_scales:
        # STEP 1- PROJECTED GRID: The non-integer locations of the projection
        # of output pixel locations to the input tensor
        projected_grid = get_projected_grid(in_sz, out_sz,
                                            scale_factor, fw, dim_by_convs,
                                            device)

        # STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify
        # the window size and the interpolation method (see inside function)
        cur_interp_method, cur_support_sz = apply_antialiasing_if_needed(
                                                            interp_method,
                                                            support_sz,
                                                            scale_factor,
                                                            antialiasing)

        # STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels
        # that influence it. Also calculate needed padding and update grid
        # accordingly
        field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw,
                                          eps, device)

        # STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view,
        # the input should be padded to handle the boundaries, and coordinates
        # should be updated. actual padding only occurs when weights are
        # applied (step 4). if using by_convs for this dim, then we need to
        # calc right and left boundaries for each filter instead.
        pad_sz, projected_grid, field_of_view = calc_pad_sz(in_sz, out_sz,
                                                            field_of_view,
                                                            projected_grid,
                                                            scale_factor,
                                                            dim_by_convs, fw,
                                                            device)

        # STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in
        # the field of view for each output pixel
        weights = get_weights(cur_interp_method, projected_grid, field_of_view)

        # STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying
        # its set of weights with the pixel values in its field of view.
        # We now multiply the fields of view with their matching weights.
        # We do this by tensor multiplication and broadcasting.
        # if by_convs is true for this dim, then we do this action by
        # convolutions. this is equivalent but faster.
        if not dim_by_convs:
            output = apply_weights(output, field_of_view, weights, dim, n_dims,
                                   pad_sz, pad_mode, fw)
        else:
            output = apply_convs(output, scale_factor, in_sz, out_sz, weights,
                                 dim, pad_sz, pad_mode, fw)
    return output
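
# A minimal usage sketch (illustrative, not part of the original module).
# `im_np` and `im_pt` stand for a hypothetical numpy array and torch tensor:
#
#     im_half = resize(im_np, scale_factors=0.5)            # downscale by 2
#     im_fixed = resize(im_pt, out_shape=(256, 256))        # explicit size
#     im_fast = resize(im_pt, scale_factors=0.5, by_convs=True)  # conv path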

def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None):
    # we start by having the output coordinates, which are just integer
    # locations. in the special case when using by_convs, we only need two
    # cycles of grid points: the first and the last.
    grid_sz = out_sz if not by_convs else scale_factor.numerator
    out_coordinates = fw_arange(grid_sz, fw, device)

    # This is projecting the output pixel locations in 1d to the input tensor,
    # as non-integer locations.
    # the following formula is derived in the paper
    # "From Discrete to Continuous Convolutions" by Shocher et al.
    return (out_coordinates / float(scale_factor) +
            (in_sz - 1) / 2 - (out_sz - 1) / (2 * float(scale_factor)))
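
# A worked instance of the formula above (illustration only): downscaling a
# length-4 signal to length 2 (scale 0.5) gives out_coordinates [0, 1] and a
# projected grid of [0 / 0.5 + 1.5 - 1.0, 1 / 0.5 + 1.5 - 1.0] = [0.5, 2.5],
# i.e. the two output pixels sit symmetrically around the input center 1.5.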

def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device):
    # for each output pixel, map which input pixels influence it, in 1d.
    # we start by calculating the leftmost neighbor, using half of the window
    # size (eps is for when the boundary is an exact int)
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)

    # then we simply take all the pixel centers in the field by counting
    # window size pixels from the left boundary
    ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device)
    return left_boundaries[:, None] + ordinal_numbers
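
# Continuing the example above (illustration only): with a cubic kernel
# (support size 4), the grid point 2.5 gets left boundary
# ceil(2.5 - 4 / 2 - eps) = 1, so its field of view is the four input pixels
# [1, 2, 3, 4].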

def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor,
                dim_by_convs, fw, device):
    if not dim_by_convs:
        # determine padding according to neighbor coords out of bound.
        # this is a generalized notion of padding: when pad<0 it means crop
        pad_sz = [-field_of_view[0, 0].item(),
                  field_of_view[-1, -1].item() - in_sz + 1]

        # since the input image will be changed by padding, coordinates of
        # both field_of_view and projected_grid need to be updated
        field_of_view += pad_sz[0]
        projected_grid += pad_sz[0]
    else:
        # only used for by_convs, to calc the boundaries of each filter. the
        # number of distinct convolutions is the numerator of the scale factor
        num_convs, stride = scale_factor.numerator, scale_factor.denominator

        # calculate left and right boundaries for each conv. left can be
        # negative and right can be bigger than in_sz. such cases imply
        # padding. however, if both are in-bounds, it means we need to crop,
        # practically applying the conv only on part of the image.
        left_pads = -field_of_view[:, 0]

        # next calc is tricky, explanation by rows:
        # 1) counting output pixels between the first position of each filter
        #    to the right boundary of the input
        # 2) dividing it by number of filters to count how many 'jumps'
        #    each filter does
        # 3) multiplying by the stride gives us the distance over the input
        #    coords done by all these jumps for each filter
        # 4) to this distance we add the right boundary of the filter when
        #    placed in its leftmost position. so now we get the right boundary
        #    of that filter in input coords.
        # 5) the padding size needed is obtained by subtracting the rightmost
        #    input coordinate. if the result is positive, padding is needed.
        #    if negative, then negative padding means shaving off pixel
        #    columns.
        right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1)  # (1)
                       // num_convs)  # (2)
                      * stride  # (3)
                      + field_of_view[:, -1]  # (4)
                      - in_sz + 1)  # (5)

        # in the by_convs case pad_sz is a list of left-right pairs, one per
        # filter
        pad_sz = list(zip(left_pads, right_pads))
    return pad_sz, projected_grid, field_of_view
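
# Example of the simple (non by_convs) branch, for illustration: if the
# field of view spans input coordinates -1..in_sz (one pixel out of bounds
# on each side), pad_sz comes out as [1, 1] and all coordinates are shifted
# right by 1 so they index into the padded input.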

def get_weights(interp_method, projected_grid, field_of_view):
    # the set of weights for each output pixel is the result of the chosen
    # interpolation method applied to the distances between projected grid
    # locations and the pixel-centers in the field of view (distances are
    # directed, can be positive or negative)
    weights = interp_method(projected_grid[:, None] - field_of_view)

    # we now carefully normalize the weights to sum to 1 per each output pixel
    sum_weights = weights.sum(1, keepdims=True)
    sum_weights[sum_weights == 0] = 1
    return weights / sum_weights
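
# Continuing the running example (illustration only): for grid point 2.5 and
# field of view [1, 2, 3, 4], the directed distances are
# [1.5, 0.5, -0.5, -1.5]; a standard cubic kernel (a=-0.5) maps these to
# [-0.0625, 0.5625, 0.5625, -0.0625], which already sum to 1, so the
# normalization leaves them unchanged.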

def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode,
                  fw):
    # for this operation we assume the resized dim is the first one.
    # so we transpose and will transpose back after multiplying
    tmp_input = fw_swapaxes(input, dim, 0, fw)

    # apply padding
    tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode)

    # field_of_view is a tensor of order 2: for each output (1d location
    # along cur dim)- a list of 1d neighbor locations.
    # note that this whole operation is applied to each dim separately,
    # this is why it is all in 1d.
    # neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:
    # for each output pixel (this time indicated in all dims), these are the
    # values of the neighbors in the 1d field of view. note that we only
    # consider neighbors along the current dim, but such a set exists for
    # every multi-dim location, hence the final tensor order is image_dims+1.
    neighbors = tmp_input[field_of_view]

    # weights is an order 2 tensor: for each output location along 1d- a list
    # of weights matching the field of view. we augment it with ones, for
    # broadcasting, so that when it multiplies some tensor the weights affect
    # only its first dim.
    tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1)))

    # now we simply multiply the weights with the neighbors, and then sum
    # along the field of view, to get a single value per out pixel
    tmp_output = (neighbors * tmp_weights).sum(1)

    # we transpose back the resized dim to its original position
    return fw_swapaxes(tmp_output, 0, dim, fw)
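
# Shape walkthrough (illustration only): resizing dim 0 of an (H, W, C)
# image with a window of K taps gives field_of_view of shape (out_H, K),
# neighbors of shape (out_H, K, W, C), weights reshaped to (out_H, K, 1, 1),
# and, after the multiply-and-sum, an output of shape (out_H, W, C).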

def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz,
                pad_mode, fw):
    # for this operation we assume the resized dim is the last one.
    # so we transpose and will transpose back after multiplying
    input = fw_swapaxes(input, dim, -1, fw)

    # the stride for all convs is the denominator of the scale factor
    stride, num_convs = scale_factor.denominator, scale_factor.numerator

    # prepare an empty tensor for the output
    tmp_out_shape = list(input.shape)
    tmp_out_shape[-1] = out_sz
    tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.device)

    # iterate over the conv operations. we have as many as the numerator
    # of the scale-factor. for each we need boundaries and a filter.
    for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)):
        # apply padding (we pad the last dim, padding can be negative)
        pad_dim = input.ndim - 1
        tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim)

        # apply convolution over the last dim. store in the output tensor with
        # positional strides so that when the loop is complete the conv
        # results are interleaved
        tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride)
    return fw_swapaxes(tmp_output, -1, dim, fw)
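
# For example (illustration only): a scale factor of Fraction(3, 2) means
# num_convs=3 and stride=2, so three distinct filters each slide over the
# input with stride 2, and their results fill output positions 0::3, 1::3
# and 2::3 respectively.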

def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs,
                         scale_tolerance, max_numerator, eps, fw):
    # eventually we must have both scale-factors and out-sizes for all in/out
    # dims. however, we support many possible partial arguments
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")
    if out_shape is not None:
        # if out_shape has fewer dims than in_shape, we default to resizing
        # the first dims for numpy and the last dims for torch
        out_shape = (list(out_shape) + list(in_shape[len(out_shape):])
                     if fw is numpy
                     else list(in_shape[:-len(out_shape)]) + list(out_shape))
        if scale_factors is None:
            # if no scale given, we calculate it as the out to in ratio
            # (not recommended)
            scale_factors = [out_sz / in_sz for out_sz, in_sz
                             in zip(out_shape, in_shape)]
    if scale_factors is not None:
        # by default, if a single number is given as scale, we assume resizing
        # two dims (most common are images with 2 spatial dims)
        scale_factors = (scale_factors
                         if isinstance(scale_factors, (list, tuple))
                         else [scale_factors, scale_factors])

        # if fewer scale_factors than in_shape dims, we default to resizing
        # the first dims for numpy and the last dims for torch
        scale_factors = (list(scale_factors) + [1] *
                         (len(in_shape) - len(scale_factors)) if fw is numpy
                         else [1] * (len(in_shape) - len(scale_factors)) +
                         list(scale_factors))
        if out_shape is None:
            # when no out_shape given, it is calculated by multiplying the
            # scale by the in_shape (not recommended)
            out_shape = [ceil(scale_factor * in_sz)
                         for scale_factor, in_sz in
                         zip(scale_factors, in_shape)]

    # next part is intentionally after out_shape is determined, for stability.
    # we fix by_convs to be a list of truth values in case it is not
    if not isinstance(by_convs, (list, tuple)):
        by_convs = [by_convs] * len(out_shape)

    # next loop fixes the scale for each dim to be either frac or float.
    # this is determined by by_convs and by the tolerance for scale accuracy.
    for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)):
        # first we turn the scale into a fraction
        if dim_by_convs:
            frac = Fraction(1 / sf).limit_denominator(max_numerator)
            frac = Fraction(numerator=frac.denominator,
                            denominator=frac.numerator)

        # if accuracy is within tolerance, the scale will be frac. if not,
        # then it will be float and the by_convs attr will be set false for
        # this dim
        if scale_tolerance is None:
            scale_tolerance = eps
        if dim_by_convs and abs(frac - sf) < scale_tolerance:
            scale_factors[ind] = frac
        else:
            scale_factors[ind] = float(sf)
            by_convs[ind] = False

    return scale_factors, out_shape, by_convs
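
# Example of the defaults policy (illustration only): for a torch tensor of
# shape (3, 256, 256) and scale_factors=0.5, the scalar first becomes
# [0.5, 0.5], is then left-padded with 1s to [1, 0.5, 0.5], and out_shape is
# computed as [3, 128, 128].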

def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    # antialiasing is "stretching" the field of view according to the scale
    # factor (only for downscaling). this is low-pass filtering. this
    # requires modifying both the interpolation (stretching the 1d
    # function and multiplying by the scale-factor) and the window size.
    scale_factor = float(scale_factor)
    if scale_factor >= 1.0 or not antialiasing:
        return interp_method, support_sz
    cur_interp_method = (lambda arg: scale_factor *
                         interp_method(scale_factor * arg))
    cur_support_sz = support_sz / scale_factor
    return cur_interp_method, cur_support_sz
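
# For instance (illustration only): downscaling by 2 (scale_factor=0.5) with
# a cubic kernel stretches the kernel to twice its width, so the support
# grows from 4 to 4 / 0.5 = 8 input pixels, and the kernel values are halved
# to keep the overall gain at 1.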

def fw_ceil(x, fw):
    if fw is numpy:
        return fw.int_(fw.ceil(x))
    else:
        return x.ceil().long()


def fw_floor(x, fw):
    if fw is numpy:
        return fw.int_(fw.floor(x))
    else:
        return x.floor().long()


def fw_cat(x, fw):
    if fw is numpy:
        return fw.concatenate(x)
    else:
        return fw.cat(x)


def fw_swapaxes(x, ax_1, ax_2, fw):
    if fw is numpy:
        return fw.swapaxes(x, ax_1, ax_2)
    else:
        return x.transpose(ax_1, ax_2)


def fw_pad(x, fw, pad_sz, pad_mode, dim=0):
    if pad_sz == (0, 0):
        return x
    if fw is numpy:
        pad_vec = [(0, 0)] * x.ndim
        pad_vec[dim] = pad_sz
        return fw.pad(x, pad_width=pad_vec, mode=pad_mode)
    else:
        if x.ndim < 3:
            x = x[None, None, ...]
        pad_vec = [0] * ((x.ndim - 2) * 2)
        pad_vec[0:2] = pad_sz
        return fw.nn.functional.pad(x.transpose(dim, -1), pad=pad_vec,
                                    mode=pad_mode).transpose(dim, -1)

def fw_conv(input, filter, stride):
    # we want to apply a 1d conv to any nd array. the way to do it is to
    # reshape the input to a 4D tensor. the first two dims are singletons, the
    # 3rd dim stores all the spatial dims that we are not convolving along
    # now. then we can apply conv2d with a 1xK filter. This convolves the same
    # way all the other dims stored in the 3rd dim, like a depthwise conv over
    # these.
    # TODO: numpy support
    reshaped_input = input.reshape(1, 1, -1, input.shape[-1])
    reshaped_output = torch.nn.functional.conv2d(reshaped_input,
                                                 filter.view(1, 1, 1, -1),
                                                 stride=(1, stride))
    return reshaped_output.reshape(*input.shape[:-1], -1)
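
# Shape sketch (illustration only): a (C, H, W) input with a length-K filter
# is reshaped to (1, 1, C*H, W), convolved with a (1, 1, 1, K) filter at
# stride (1, stride), and the (1, 1, C*H, W_out) result is reshaped back to
# (C, H, W_out). Rows never mix because the filter height is 1.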

def fw_arange(upper_bound, fw, device):
    if fw is numpy:
        return fw.arange(upper_bound)
    else:
        return fw.arange(upper_bound, device=device)


def fw_empty(shape, fw, device):
    if fw is numpy:
        return fw.empty(shape)
    else:
        return fw.empty(size=(*shape,), device=device)
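

# A minimal self-test sketch, not part of the original module, assuming numpy
# is installed and the module is run within its package (the relative import
# of interp_methods above requires e.g. `python -m <package>.<module>`).
if __name__ == '__main__':
    if numpy is not None:
        im = numpy.random.rand(32, 32)
        # downscale by 2 with the default cubic kernel and antialiasing.
        # note: pad_mode is forwarded to numpy.pad for numpy inputs, so we
        # pass a numpy mode name ('edge') rather than torch's 'replicate'.
        small = resize(im, scale_factors=0.5, pad_mode='edge')
        assert small.shape == (16, 16)
        # upscale back using an explicit output shape
        big = resize(small, out_shape=(32, 32), pad_mode='edge')
        assert big.shape == (32, 32)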