# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from imaginaire.third_party.upfirdn2d import BlurDownsample, BlurUpsample
from .conv import Conv2dBlock


class _BaseDeepResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, groups, bias, padding_mode,
                 weight_norm_type, weight_norm_params,
                 activation_norm_type, activation_norm_params,
                 skip_activation_norm, skip_nonlinearity,
                 nonlinearity, inplace_nonlinearity, apply_noise,
                 hidden_channels_equal_out_channels,
                 order, block, learn_shortcut, output_scale, skip_block=None,
                 blur=True, border_free=True, resample_first=True,
                 skip_weight_norm=True, hidden_channel_ratio=4):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.output_scale = output_scale
        self.resample_first = resample_first
        self.stride = stride
        self.blur = blur
        self.border_free = border_free
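        # Border-free mode is accepted for API compatibility but is not
        # implemented by this block, hence the assertion below.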
        assert not border_free
        if skip_block is None:
            skip_block = block

        if order == 'pre_act':
            order = 'NACNAC'
        if isinstance(bias, bool):
            # The bias for conv_block_0, conv_block_1, and conv_block_s.
            biases = [bias, bias, bias]
        elif isinstance(bias, list):
            if len(bias) == 3:
                biases = bias
            else:
                raise ValueError('Bias list must have 3 elements.')
        else:
            raise ValueError('Bias must be either a bool or a list of '
                             '3 bools.')
        self.learn_shortcut = learn_shortcut
        if len(order) > 6 or len(order) < 5:
            raise ValueError('order must be either 5 or 6 characters')
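        # Note: ``hidden_channels_equal_out_channels`` is accepted for API
        # compatibility but is unused here; the bottleneck width is always
        # ``in_channels // hidden_channel_ratio``.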
        hidden_channels = in_channels // hidden_channel_ratio

        # Parameters.
        residual_params = {}
        shortcut_params = {}
        base_params = dict(dilation=dilation,
                           groups=groups,
                           padding_mode=padding_mode)
        residual_params.update(base_params)
        residual_params.update(
            dict(activation_norm_type=activation_norm_type,
                 activation_norm_params=activation_norm_params,
                 weight_norm_type=weight_norm_type,
                 weight_norm_params=weight_norm_params,
                 apply_noise=apply_noise)
        )
        shortcut_params.update(base_params)
        shortcut_params.update(dict(kernel_size=1))
        if skip_activation_norm:
            shortcut_params.update(
                dict(activation_norm_type=activation_norm_type,
                     activation_norm_params=activation_norm_params,
                     apply_noise=False))
        if skip_weight_norm:
            shortcut_params.update(
                dict(weight_norm_type=weight_norm_type,
                     weight_norm_params=weight_norm_params))

        # Residual branch.
        if order.find('A') < order.find('C') and \
                (activation_norm_type == '' or activation_norm_type == 'none'):
            # Nonlinearity is the first operation in the residual path.
            # In-place nonlinearity will modify the input variable and cause
            # backward error.
            first_inplace = False
        else:
            first_inplace = inplace_nonlinearity

        (first_stride, second_stride, shortcut_stride,
         first_blur, second_blur, shortcut_blur) = self._get_stride_blur()

        self.conv_block_1x1_in = block(
            in_channels, hidden_channels,
            1, 1, 0,
            bias=biases[0],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=first_inplace,
            **residual_params
        )

        self.conv_block_0 = block(
            hidden_channels, hidden_channels,
            kernel_size=(2 if self.border_free and first_stride < 1
                         else kernel_size),
            padding=padding,
            bias=biases[0],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=inplace_nonlinearity,
            stride=first_stride,
            blur=first_blur,
            **residual_params
        )
        self.conv_block_1 = block(
            hidden_channels, hidden_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=biases[1],
            nonlinearity=nonlinearity,
            order=order[3:],
            inplace_nonlinearity=inplace_nonlinearity,
            stride=second_stride,
            blur=second_blur,
            **residual_params
        )

        self.conv_block_1x1_out = block(
            hidden_channels, out_channels,
            1, 1, 0,
            bias=biases[1],
            nonlinearity=nonlinearity,
            order=order[0:3],
            inplace_nonlinearity=inplace_nonlinearity,
            **residual_params
        )

        # Shortcut branch.
        if self.learn_shortcut:
            if skip_nonlinearity:
                skip_nonlinearity_type = nonlinearity
            else:
                skip_nonlinearity_type = ''
            self.conv_block_s = skip_block(in_channels, out_channels,
                                           bias=biases[2],
                                           nonlinearity=skip_nonlinearity_type,
                                           order=order[0:3],
                                           stride=shortcut_stride,
                                           blur=shortcut_blur,
                                           **shortcut_params)
        elif in_channels < out_channels:
            if skip_nonlinearity:
                skip_nonlinearity_type = nonlinearity
            else:
                skip_nonlinearity_type = ''
            self.conv_block_s = skip_block(in_channels,
                                           out_channels - in_channels,
                                           bias=biases[2],
                                           nonlinearity=skip_nonlinearity_type,
                                           order=order[0:3],
                                           stride=shortcut_stride,
                                           blur=shortcut_blur,
                                           **shortcut_params)

        # Whether this block expects conditional inputs.
        self.conditional = \
            getattr(self.conv_block_0, 'conditional', False) or \
            getattr(self.conv_block_1, 'conditional', False) or \
            getattr(self.conv_block_1x1_in, 'conditional', False) or \
            getattr(self.conv_block_1x1_out, 'conditional', False)

    def _get_stride_blur(self):
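        # Decide the stride and blur flag for each residual conv and for the
        # shortcut, and build the shortcut resampling layer. For example,
        # with stride=2 and blur=True the residual path downsamples in its
        # second conv (anti-aliased), while the shortcut input is
        # blur-downsampled and then convolved with stride 1.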
        if self.stride > 1:
            # Downsampling.
            first_stride, second_stride = 1, self.stride
            first_blur, second_blur = False, self.blur
            shortcut_blur = False
            shortcut_stride = 1
            if self.blur:
                # The shortcut branch uses blur_downsample + stride-1 conv
                if self.border_free:
                    self.resample = nn.AvgPool2d(2)
                else:
                    self.resample = BlurDownsample()
            else:
                shortcut_stride = self.stride
                self.resample = nn.AvgPool2d(2)
        elif self.stride < 1:
            # Upsampling.
            first_stride, second_stride = self.stride, 1
            first_blur, second_blur = self.blur, False
            shortcut_blur = False
            shortcut_stride = 1
            if self.blur:
                # The shortcut branch uses blur_upsample + stride-1 conv
                if self.border_free:
                    self.resample = nn.Upsample(scale_factor=2,
                                                mode='bilinear')
                else:
                    self.resample = BlurUpsample()
            else:
                shortcut_stride = self.stride
                self.resample = nn.Upsample(scale_factor=2)
        else:
            first_stride = second_stride = 1
            first_blur = second_blur = False
            shortcut_stride = 1
            shortcut_blur = False
            self.resample = None
        return (first_stride, second_stride, shortcut_stride,
                first_blur, second_blur, shortcut_blur)

    def conv_blocks(
            self, x, *cond_inputs, separate_cond=False, **kw_cond_inputs
    ):
        if separate_cond:
            assert len(cond_inputs) == 4
            dx = self.conv_block_1x1_in(x, cond_inputs[0],
                                        **kw_cond_inputs.get('kwargs_0', {}))
            dx = self.conv_block_0(dx, cond_inputs[1],
                                   **kw_cond_inputs.get('kwargs_1', {}))
            dx = self.conv_block_1(dx, cond_inputs[2],
                                   **kw_cond_inputs.get('kwargs_2', {}))
            dx = self.conv_block_1x1_out(dx, cond_inputs[3],
                                         **kw_cond_inputs.get('kwargs_3', {}))
        else:
            dx = self.conv_block_1x1_in(x, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_0(dx, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_1(dx, *cond_inputs, **kw_cond_inputs)
            dx = self.conv_block_1x1_out(dx, *cond_inputs, **kw_cond_inputs)
        return dx

    def forward(self, x, *cond_inputs, do_checkpoint=False, **kw_cond_inputs):
        if do_checkpoint:
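            # Gradient checkpointing: recompute conv_blocks in the backward
            # pass to save activation memory.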
            dx = checkpoint(self.conv_blocks, x, *cond_inputs, **kw_cond_inputs)
        else:
            dx = self.conv_blocks(x, *cond_inputs, **kw_cond_inputs)

        if self.resample_first and self.resample is not None:
            x = self.resample(x)
        if self.learn_shortcut:
            x_shortcut = self.conv_block_s(
                x, *cond_inputs, **kw_cond_inputs
            )
        elif self.in_channels < self.out_channels:
            x_shortcut_pad = self.conv_block_s(
                x, *cond_inputs, **kw_cond_inputs
            )
            x_shortcut = torch.cat((x, x_shortcut_pad), dim=1)
        elif self.in_channels > self.out_channels:
            x_shortcut = x[:, :self.out_channels, :, :]
        else:
            x_shortcut = x
        if not self.resample_first and self.resample is not None:
            x_shortcut = self.resample(x_shortcut)

        output = x_shortcut + dx
        return self.output_scale * output

    def extra_repr(self):
        s = 'output_scale={output_scale}'
        return s.format(**self.__dict__)


class DeepRes2dBlock(_BaseDeepResBlock):
    r"""Residual block for 2D input.

    Args:
        in_channels (int) : Number of channels in the input tensor.
        out_channels (int) : Number of channels in the output tensor.
        kernel_size (int, optional, default=3): Kernel size for the
            convolutional filters in the residual link.
        stride (int or float, optional, default=1): Stride of the residual
            convolutions; ``stride > 1`` downsamples and ``stride < 1``
            upsamples.
        padding (int, optional, default=1): Padding size.
        dilation (int, optional, default=1): Dilation factor.
        groups (int, optional, default=1): Number of convolutional/linear
            groups.
        padding_mode (string, optional, default='zeros'): Type of padding:
            ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
        weight_norm_type (str, optional, default='none'):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``
            or ``'weight_demod'``.
        weight_norm_params (obj, optional, default=None):
            Parameters of weight normalization.
            If not ``None``, ``weight_norm_params.__dict__`` will be used as
            keyword arguments when initializing weight normalization.
        activation_norm_type (str, optional, default='none'):
            Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``,
            ``'layer'``,  ``'layer_2d'``, ``'group'``, ``'adaptive'``,
            ``'spatially_adaptive'`` or ``'hyper_spatially_adaptive'``.
        activation_norm_params (obj, optional, default=None):
            Parameters of activation normalization.
            If not ``None``, ``activation_norm_params.__dict__`` will be used as
            keyword arguments when initializing activation normalization.
        skip_activation_norm (bool, optional, default=True): If ``True`` and
            ``learn_shortcut`` is also ``True``, applies activation norm to the
            learned shortcut connection.
        skip_nonlinearity (bool, optional, default=True): If ``True`` and
            ``learn_shortcut`` is also ``True``, applies nonlinearity to the
            learned shortcut connection.
        nonlinearity (str, optional, default='none'):
            Type of nonlinear activation function in the residual link.
            ``'none'``, ``'relu'``, ``'leakyrelu'``, ``'prelu'``,
            ``'tanh'`` , ``'sigmoid'`` or ``'softmax'``.
        inplace_nonlinearity (bool, optional, default=False): If ``True``,
            set ``inplace=True`` when initializing the nonlinearity layers.
        apply_noise (bool, optional, default=False): If ``True``, adds
            Gaussian noise with learnable magnitude to the convolution output.
        hidden_channels_equal_out_channels (bool, optional, default=False):
            If ``True``, set the hidden channel number to be equal to the
            output channel number. If ``False``, the hidden channel number
            equals the smaller of the input channel number and the output
            channel number. Note that this deep block currently ignores the
            flag and always uses ``in_channels // 4`` hidden channels.
        order (str, optional, default='CNACNA'): Order of operations
            in the residual link.
            ``'C'``: convolution,
            ``'N'``: normalization,
            ``'A'``: nonlinear activation.
        learn_shortcut (bool, optional, default=False): If ``True``, always use
            a convolutional shortcut instead of an identity one, otherwise only
            use a convolutional one if input and output have different number of
            channels.
        output_scale (float, optional, default=1): Scale applied to the block
            output before it is returned.
        blur (bool, optional, default=True): If ``True``, use anti-aliased
            blur resampling (``BlurDownsample``/``BlurUpsample``) when the
            block changes resolution.
        resample_first (bool, optional, default=True): If ``True``, resample
            the shortcut input before the shortcut convolution; otherwise
            resample the shortcut output.
        border_free (bool, optional, default=False): Border-free mode;
            currently unsupported and must be ``False``.
        skip_weight_norm (bool, optional, default=True): If ``True``, apply
            weight normalization to the learned shortcut convolution.
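
    Example::

        >>> # Illustrative only (not from the original docs); the blur path
        >>> # requires imaginaire's compiled ``upfirdn2d`` extension.
        >>> block = DeepRes2dBlock(64, 128, stride=2)
        >>> out = block(torch.randn(1, 64, 32, 32))
        >>> out.shape
        torch.Size([1, 128, 16, 16])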
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, bias=True,
                 padding_mode='zeros',
                 weight_norm_type='none', weight_norm_params=None,
                 activation_norm_type='none', activation_norm_params=None,
                 skip_activation_norm=True, skip_nonlinearity=False,
                 skip_weight_norm=True,
                 nonlinearity='leakyrelu', inplace_nonlinearity=False,
                 apply_noise=False, hidden_channels_equal_out_channels=False,
                 order='CNACNA', learn_shortcut=False, output_scale=1,
                 blur=True, resample_first=True, border_free=False):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode,
                         weight_norm_type, weight_norm_params,
                         activation_norm_type, activation_norm_params,
                         skip_activation_norm, skip_nonlinearity, nonlinearity,
                         inplace_nonlinearity, apply_noise,
                         hidden_channels_equal_out_channels, order, Conv2dBlock,
                         learn_shortcut, output_scale, blur=blur,
                         resample_first=resample_first, border_free=border_free,
                         skip_weight_norm=skip_weight_norm)
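

if __name__ == '__main__':
    # Minimal smoke test; a sketch that is not part of the original module
    # and assumes the imaginaire package (including its upfirdn2d extension)
    # is installed. stride=1 keeps the blur/resampling paths unused.
    block = DeepRes2dBlock(32, 64)
    x = torch.randn(2, 32, 16, 16)
    y = block(x)
    # With in_channels < out_channels and learn_shortcut=False, the shortcut
    # conv produces 64 - 32 = 32 extra channels that are concatenated with x.
    assert y.shape == (2, 64, 16, 16)
    print('DeepRes2dBlock output shape:', tuple(y.shape))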