# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from types import SimpleNamespace
import torch
from torch import nn
from .misc import ApplyNoise
from imaginaire.third_party.upfirdn2d.upfirdn2d import Blur
class ViT2dBlock(nn.Module):
r"""An abstract wrapper class that wraps a torch convolution or linear layer
with normalization and nonlinearity.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias, padding_mode,
weight_norm_type, weight_norm_params,
activation_norm_type, activation_norm_params,
nonlinearity, inplace_nonlinearity,
apply_noise, blur, order, input_dim, clamp,
blur_kernel=(1, 3, 3, 1), output_scale=None,
init_gain=1.0):
super().__init__()
from .nonlinearity import get_nonlinearity_layer
from .weight_norm import get_weight_norm_layer
from .activation_norm import get_activation_norm_layer
self.weight_norm_type = weight_norm_type
self.stride = stride
self.clamp = clamp
self.init_gain = init_gain
# Nonlinearity layer.
if 'fused' in nonlinearity:
# Fusing nonlinearity with bias.
lr_mul = getattr(weight_norm_params, 'lr_mul', 1)
conv_before_nonlinearity = order.find('C') < order.find('A')
if conv_before_nonlinearity:
assert bias
bias = False
channel = out_channels if conv_before_nonlinearity else in_channels
nonlinearity_layer = get_nonlinearity_layer(
nonlinearity, inplace=inplace_nonlinearity,
num_channels=channel, lr_mul=lr_mul)
else:
nonlinearity_layer = get_nonlinearity_layer(
nonlinearity, inplace=inplace_nonlinearity)
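        # Note: for 'fused*' nonlinearities the bias of a conv placed before
        # the activation is disabled above (assert bias; bias = False) and is
        # applied inside the fused activation layer instead.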
# Noise injection layer.
if apply_noise:
order = order.replace('C', 'CG')
noise_layer = ApplyNoise()
else:
noise_layer = None
# Convolutional layer.
if blur:
if stride == 2:
# Blur - Conv - Noise - Activate
p = (len(blur_kernel) - 2) + (kernel_size - 1)
pad0, pad1 = (p + 1) // 2, p // 2
padding = 0
blur_layer = Blur(
blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
)
order = order.replace('C', 'BC')
elif stride == 0.5:
# Conv - Blur - Noise - Activate
padding = 0
p = (len(blur_kernel) - 2) - (kernel_size - 1)
pad0, pad1 = (p + 1) // 2 + 1, p // 2 + 1
blur_layer = Blur(
blur_kernel, pad=(pad0, pad1), padding_mode=padding_mode
)
order = order.replace('C', 'CB')
elif stride == 1:
# No blur for now
blur_layer = nn.Identity()
else:
raise NotImplementedError
else:
blur_layer = nn.Identity()
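        # When blur is enabled, the order string is rewritten so the blur runs
        # before a stride-2 (downsampling) conv ('BC') or after a stride-0.5
        # (upsampling) conv ('CB'); the conv's own padding is zeroed and the
        # asymmetric (pad0, pad1) amounts on the upfirdn2d Blur absorb it.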
if weight_norm_params is None:
weight_norm_params = SimpleNamespace()
weight_norm = get_weight_norm_layer(
weight_norm_type, **vars(weight_norm_params))
conv_layer = weight_norm(self._get_conv_layer(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, bias, padding_mode, input_dim))
# Normalization layer.
conv_before_norm = order.find('C') < order.find('N')
norm_channels = out_channels if conv_before_norm else in_channels
if activation_norm_params is None:
activation_norm_params = SimpleNamespace()
activation_norm_layer = get_activation_norm_layer(
norm_channels,
activation_norm_type,
input_dim,
**vars(activation_norm_params))
# Mapping from operation names to layers.
mappings = {'C': {'conv': conv_layer},
'N': {'norm': activation_norm_layer},
'A': {'nonlinearity': nonlinearity_layer}}
mappings.update({'B': {'blur': blur_layer}})
mappings.update({'G': {'noise': noise_layer}})
# All layers in order.
self.layers = nn.ModuleDict()
for op in order:
if list(mappings[op].values())[0] is not None:
self.layers.update(mappings[op])
# Whether this block expects conditional inputs.
self.conditional = \
getattr(conv_layer, 'conditional', False) or \
getattr(activation_norm_layer, 'conditional', False)
if output_scale is not None:
self.output_scale = nn.Parameter(torch.tensor(output_scale))
else:
self.register_parameter("output_scale", None)
def forward(self, x, *cond_inputs, **kw_cond_inputs):
r"""
Args:
x (tensor): Input tensor.
cond_inputs (list of tensors) : Conditional input tensors.
kw_cond_inputs (dict) : Keyword conditional inputs.
"""
for key, layer in self.layers.items():
if getattr(layer, 'conditional', False):
# Layers that require conditional inputs.
x = layer(x, *cond_inputs, **kw_cond_inputs)
else:
x = layer(x)
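            # If requested, clamp conv outputs from above (in place) to keep
            # activations bounded.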
if self.clamp is not None and isinstance(layer, nn.Conv2d):
x.clamp_(max=self.clamp)
if key == 'conv':
if self.output_scale is not None:
x = x * self.output_scale
return x
def _get_conv_layer(self, in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias, padding_mode,
input_dim):
# Returns the convolutional layer.
if input_dim == 0:
layer = nn.Linear(in_channels, out_channels, bias)
else:
if stride < 1: # Fractionally-strided convolution.
padding_mode = 'zeros'
assert padding == 0
layer_type = getattr(nn, f'ConvTranspose{input_dim}d')
stride = round(1 / stride)
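                # A fractional stride such as 0.5 becomes a transposed conv
                # with integer stride round(1 / stride) = 2; padding must
                # already be 0 on this path.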
else:
layer_type = getattr(nn, f'Conv{input_dim}d')
layer = layer_type(
in_channels, out_channels, kernel_size, stride, padding,
dilation=dilation, groups=groups, bias=bias,
padding_mode=padding_mode
)
return layer
def __repr__(self):
main_str = self._get_name() + '('
child_lines = []
for name, layer in self.layers.items():
mod_str = repr(layer)
if name == 'conv' and self.weight_norm_type != 'none' and \
self.weight_norm_type != '':
mod_str = mod_str[:-1] + \
', weight_norm={}'.format(self.weight_norm_type) + ')'
if name == 'conv' and getattr(layer, 'base_lr_mul', 1) != 1:
mod_str = mod_str[:-1] + \
', lr_mul={}'.format(layer.base_lr_mul) + ')'
mod_str = self._addindent(mod_str, 2)
child_lines.append(mod_str)
if len(child_lines) == 1:
main_str += child_lines[0]
else:
main_str += '\n ' + '\n '.join(child_lines) + '\n'
main_str += ')'
return main_str
@staticmethod
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
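

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file. It assumes the
    # companion imaginaire helpers accept 'none' for the weight/activation
    # norm types and 'relu' for the nonlinearity, and it has to be exercised
    # from within the imaginaire package because of the relative imports
    # above (e.g. via ``python -m`` on this module's dotted path).
    block = ViT2dBlock(
        in_channels=64, out_channels=128, kernel_size=3, stride=1,
        padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros',
        weight_norm_type='none', weight_norm_params=None,
        activation_norm_type='none', activation_norm_params=None,
        nonlinearity='relu', inplace_nonlinearity=False,
        apply_noise=False, blur=False, order='CNA', input_dim=2, clamp=None)
    y = block(torch.randn(1, 64, 32, 32))
    print(y.shape)  # expected: torch.Size([1, 128, 32, 32])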