# sat3density/imaginaire/layers/residual_deep.py
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample
from .conv import Conv2dBlock


class _BaseDeepResBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding, dilation, groups, bias, padding_mode,
weight_norm_type, weight_norm_params,
activation_norm_type, activation_norm_params,
skip_activation_norm, skip_nonlinearity,
nonlinearity, inplace_nonlinearity, apply_noise,
hidden_channels_equal_out_channels,
order, block, learn_shortcut, output_scale, skip_block=None,
blur=True, border_free=True, resample_first=True,
skip_weight_norm=True, hidden_channel_ratio=4):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.output_scale = output_scale
self.resample_first = resample_first
self.stride = stride
self.blur = blur
self.border_free = border_free
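        # Only border_free=False is supported; the assertion below rejects
        # the border-free variant.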
assert not border_free
if skip_block is None:
skip_block = block
if order == 'pre_act':
order = 'NACNAC'
if isinstance(bias, bool):
# The bias for conv_block_0, conv_block_1, and conv_block_s.
biases = [bias, bias, bias]
elif isinstance(bias, list):
if len(bias) == 3:
biases = bias
            else:
                raise ValueError('The bias list must contain 3 elements.')
        else:
            raise ValueError('Bias must be either a bool or a list of 3 bools.')
self.learn_shortcut = learn_shortcut
if len(order) > 6 or len(order) < 5:
raise ValueError('order must be either 5 or 6 characters')
hidden_channels = in_channels // hidden_channel_ratio
# Parameters.
residual_params = {}
shortcut_params = {}
base_params = dict(dilation=dilation,
groups=groups,
padding_mode=padding_mode)
residual_params.update(base_params)
residual_params.update(
dict(activation_norm_type=activation_norm_type,
activation_norm_params=activation_norm_params,
weight_norm_type=weight_norm_type,
weight_norm_params=weight_norm_params,
apply_noise=apply_noise)
)
shortcut_params.update(base_params)
shortcut_params.update(dict(kernel_size=1))
if skip_activation_norm:
shortcut_params.update(
dict(activation_norm_type=activation_norm_type,
activation_norm_params=activation_norm_params,
apply_noise=False))
if skip_weight_norm:
shortcut_params.update(
dict(weight_norm_type=weight_norm_type,
weight_norm_params=weight_norm_params))
# Residual branch.
if order.find('A') < order.find('C') and \
(activation_norm_type == '' or activation_norm_type == 'none'):
# Nonlinearity is the first operation in the residual path.
# In-place nonlinearity will modify the input variable and cause
# backward error.
first_inplace = False
else:
first_inplace = inplace_nonlinearity
(first_stride, second_stride, shortcut_stride,
first_blur, second_blur, shortcut_blur) = self._get_stride_blur()
self.conv_block_1x1_in = block(
in_channels, hidden_channels,
1, 1, 0,
bias=biases[0],
nonlinearity=nonlinearity,
order=order[0:3],
inplace_nonlinearity=first_inplace,
**residual_params
)
self.conv_block_0 = block(
hidden_channels, hidden_channels,
            kernel_size=(2 if self.border_free and first_stride < 1
                         else kernel_size),
padding=padding,
bias=biases[0],
nonlinearity=nonlinearity,
order=order[0:3],
inplace_nonlinearity=inplace_nonlinearity,
stride=first_stride,
blur=first_blur,
**residual_params
)
self.conv_block_1 = block(
hidden_channels, hidden_channels,
kernel_size=kernel_size,
padding=padding,
bias=biases[1],
nonlinearity=nonlinearity,
order=order[3:],
inplace_nonlinearity=inplace_nonlinearity,
stride=second_stride,
blur=second_blur,
**residual_params
)
self.conv_block_1x1_out = block(
hidden_channels, out_channels,
1, 1, 0,
bias=biases[1],
nonlinearity=nonlinearity,
order=order[0:3],
inplace_nonlinearity=inplace_nonlinearity,
**residual_params
)
# Shortcut branch.
if self.learn_shortcut:
if skip_nonlinearity:
skip_nonlinearity_type = nonlinearity
else:
skip_nonlinearity_type = ''
self.conv_block_s = skip_block(in_channels, out_channels,
bias=biases[2],
nonlinearity=skip_nonlinearity_type,
order=order[0:3],
stride=shortcut_stride,
blur=shortcut_blur,
**shortcut_params)
elif in_channels < out_channels:
if skip_nonlinearity:
skip_nonlinearity_type = nonlinearity
else:
skip_nonlinearity_type = ''
self.conv_block_s = skip_block(in_channels,
out_channels - in_channels,
bias=biases[2],
nonlinearity=skip_nonlinearity_type,
order=order[0:3],
stride=shortcut_stride,
blur=shortcut_blur,
**shortcut_params)
# Whether this block expects conditional inputs.
self.conditional = \
getattr(self.conv_block_0, 'conditional', False) or \
getattr(self.conv_block_1, 'conditional', False) or \
getattr(self.conv_block_1x1_in, 'conditional', False) or \
            getattr(self.conv_block_1x1_out, 'conditional', False)

    def _get_stride_blur(self):
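        r"""Derive the strides and blur flags for the two hidden convolutions
        and the shortcut convolution, and set ``self.resample`` to the module
        used to resize the shortcut/identity branch (``None`` when
        ``stride == 1``).
        """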
if self.stride > 1:
# Downsampling.
first_stride, second_stride = 1, self.stride
first_blur, second_blur = False, self.blur
shortcut_blur = False
shortcut_stride = 1
if self.blur:
# The shortcut branch uses blur_downsample + stride-1 conv
if self.border_free:
self.resample = nn.AvgPool2d(2)
else:
self.resample = BlurDownsample()
else:
shortcut_stride = self.stride
self.resample = nn.AvgPool2d(2)
elif self.stride < 1:
# Upsampling.
first_stride, second_stride = self.stride, 1
first_blur, second_blur = self.blur, False
shortcut_blur = False
shortcut_stride = 1
if self.blur:
# The shortcut branch uses blur_upsample + stride-1 conv
if self.border_free:
self.resample = nn.Upsample(scale_factor=2,
mode='bilinear')
else:
self.resample = BlurUpsample()
else:
shortcut_stride = self.stride
self.resample = nn.Upsample(scale_factor=2)
else:
first_stride = second_stride = 1
first_blur = second_blur = False
shortcut_stride = 1
shortcut_blur = False
self.resample = None
return (first_stride, second_stride, shortcut_stride,
                first_blur, second_blur, shortcut_blur)

    def conv_blocks(
            self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
    ):
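        r"""Run the residual branch: 1x1 input conv, two hidden convs, and a
        1x1 output conv. When ``separate_cond`` is ``True``, ``cond_inputs``
        must contain exactly four conditional inputs, one per sub-block, with
        optional per-block keyword arguments ``kwargs_0`` ... ``kwargs_3``;
        otherwise all conditional inputs are forwarded to every sub-block.
        """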
if separate_cond:
            assert len(cond_inputs) == 4, \
                'separate_cond requires exactly 4 conditional inputs.'
dx = self.conv_block_1x1_in(x, cond_inputs[0],
**kw_cond_inputs.get('kwargs_0', {}))
dx = self.conv_block_0(dx, cond_inputs[1],
**kw_cond_inputs.get('kwargs_1', {}))
dx = self.conv_block_1(dx, cond_inputs[2],
**kw_cond_inputs.get('kwargs_2', {}))
dx = self.conv_block_1x1_out(dx, cond_inputs[3],
**kw_cond_inputs.get('kwargs_3', {}))
else:
dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs)
dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs)
dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs)
        return dx

    def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
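        r"""Sum the residual branch and the shortcut branch (resampled and,
        when the channel counts differ, convolved, channel-padded or
        channel-sliced to match ``out_channels``), then scale the result by
        ``output_scale``. Pass ``do_checkpoint=True`` to recompute the
        residual branch during the backward pass and save activation memory.
        """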
if do_checkpoint:
dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs)
else:
dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)
if self.resample_first and self.resample is not None:
x = self.resample(x)
if self.learn_shortcut:
x_shortcut = self.conv_block_s(
x, *cond_inputs, **kw_cond_inputs
)
elif self.in_channels < self.out_channels:
x_shortcut_pad = self.conv_block_s(
x, *cond_inputs, **kw_cond_inputs
)
x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
elif self.in_channels > self.out_channels:
x_shortcut = x[:, :self.out_channels, :, :]
else:
x_shortcut = x
if not self.resample_first and self.resample is not None:
x_shortcut = self.resample(x_shortcut)
output = x_shortcut + dx
        return self.output_scale * output

    def extra_repr(self):
s = 'output_scale={output_scale}'
        return s.format(**self.__dict__)


class DeepRes2dBlock(_BaseDeepResBlock):
    r"""Residual block for 2D input.

    Args:
in_channels (int) : Number of channels in the input tensor.
out_channels (int) : Number of channels in the output tensor.
kernel_size (int, optional, default=3): Kernel size for the
convolutional filters in the residual link.
padding (int, optional, default=1): Padding size.
dilation (int, optional, default=1): Dilation factor.
groups (int, optional, default=1): Number of convolutional/linear
groups.
padding_mode (string, optional, default='zeros'): Type of padding:
``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
weight_norm_type (str, optional, default='none'):
Type of weight normalization.
``'none'``, ``'spectral'``, ``'weight'``
or ``'weight_demod'``.
weight_norm_params (obj, optional, default=None):
Parameters of weight normalization.
If not ``None``, ``weight_norm_params.__dict__`` will be used as
keyword arguments when initializing weight normalization.
activation_norm_type (str, optional, default='none'):
Type of activation normalization.
``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
``'layer'``, ``'layer_2d'``, ``'group'``, ``'adaptive'``,
``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
activation_norm_params (obj, optional, default=None):
Parameters of activation normalization.
If not ``None``, ``activation_norm_params.__dict__`` will be used as
keyword arguments when initializing activation normalization.
skip_activation_norm (bool, optional, default=True): If ``True`` and
``learn_shortcut`` is also ``True``, applies activation norm to the
learned shortcut connection.
        skip_nonlinearity (bool, optional, default=False): If ``True`` and
            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
            learned shortcut connection.
        nonlinearity (str, optional, default='leakyrelu'):
            Type of nonlinear activation function in the residual link.
            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'``, ``'sigmoid'`` or ``'softmax'``.
inplace_nonlinearity (bool, optional, default=False): If ``True``,
set ``inplace=True`` when initializing the nonlinearity layers.
apply_noise (bool, optional, default=False): If ``True``, adds
Gaussian noise with learnable magnitude to the convolution output.
        hidden_channels_equal_out_channels (bool, optional, default=False):
            Ignored by this deep block; the bottleneck hidden channel number
            is always ``in_channels // hidden_channel_ratio`` (ratio 4 by
            default).
order (str, optional, default='CNACNA'): Order of operations
in the residual link.
``'C'``: convolution,
``'N'``: normalization,
``'A'``: nonlinear activation.
learn_shortcut (bool, optional, default=False): If ``True``, always use
a convolutional shortcut instead of an identity one, otherwise only
use a convolutional one if input and output have different number of
channels.
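        stride (int or float, optional, default=1): Stride of the block.
            Values greater than 1 downsample the input; values smaller than 1
            upsample it.
        bias (bool or list, optional, default=True): Whether the convolutions
            use biases; a list of 3 bools sets the biases of ``conv_block_0``,
            ``conv_block_1`` and ``conv_block_s`` separately.
        output_scale (float, optional, default=1): Scale factor applied to
            the sum of the shortcut and residual branches before it is
            returned.
        blur (bool, optional, default=True): If ``True``, uses anti-aliasing
            blur filtering (``BlurDownsample`` / ``BlurUpsample``) when the
            block resamples its input.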
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding=1, dilation=1, groups=1, bias=True,
padding_mode='zeros',
weight_norm_type='none', weight_norm_params=None,
activation_norm_type='none', activation_norm_params=None,
skip_activation_norm=True, skip_nonlinearity=False,
skip_weight_norm=True,
nonlinearity='leakyrelu', inplace_nonlinearity=False,
apply_noise=False, hidden_channels_equal_out_channels=False,
order='CNACNA', learn_shortcut=False, output_scale=1,
blur=True, resample_first=True, border_free=False):
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias, padding_mode,
weight_norm_type, weight_norm_params,
activation_norm_type, activation_norm_params,
skip_activation_norm, skip_nonlinearity, nonlinearity,
inplace_nonlinearity, apply_noise,
hidden_channels_equal_out_channels, order, Conv2dBlock,
learn_shortcut, output_scale, blur=blur,
resample_first=resample_first, border_free=border_free,
skip_weight_norm=skip_weight_norm)
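

# Minimal usage sketch (illustrative, not part of the original module; assumes
# the imaginaire package and its third-party ops are importable). With the
# default stride of 1 and more input than output channels, the shortcut is a
# channel slice of the input and the spatial size is preserved.
if __name__ == '__main__':
    block = DeepRes2dBlock(in_channels=64, out_channels=32, kernel_size=3)
    y = block(torch.randn(1, 64, 128, 128))
    print(y.shape)  # Expected: torch.Size([1, 32, 128, 128])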