# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Upsample as NearestUpsample

from imaginaire.layers import Conv2dBlock, Res2dBlock
from imaginaire.utils.data import (get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print


class Generator(nn.Module):
    r"""Pix2pixHD coarse-to-fine generator constructor.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
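
    Example::

        # A usage sketch, assuming ``cfg`` is a parsed imaginaire yaml
        # config with ``gen`` and ``data`` sections (hypothetical names).
        net = Generator(cfg.gen, cfg.data)
        # The input dict also needs 'images' and 'instance_maps' entries
        # when the feature encoder is enabled; see forward() below.
        out = net({'label': label})    # label: (N, C_label, H, W)
        fake = out['fake_images']      # tanh output, range [-1, 1]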
    """

    def __init__(self, gen_cfg, data_cfg):
        super().__init__()
        # pix2pixHD has a global generator.
        global_gen_cfg = gen_cfg.global_generator
        num_filters_global = getattr(global_gen_cfg, 'num_filters', 64)
        # Optionally, it can have several local enhancers. They are useful
        # for generating high resolution images.
        local_gen_cfg = gen_cfg.local_enhancer
        self.num_local_enhancers = num_local_enhancers = \
            getattr(local_gen_cfg, 'num_enhancers', 1)
        # By default, pix2pixHD uses instance normalization.
        activation_norm_type = getattr(gen_cfg, 'activation_norm_type',
                                       'instance')
        activation_norm_params = getattr(gen_cfg, 'activation_norm_params',
                                         None)
        weight_norm_type = getattr(gen_cfg, 'weight_norm_type', '')
        padding_mode = getattr(gen_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  activation_norm_params=activation_norm_params,
                                  nonlinearity='relu')
        base_res_block = partial(Res2dBlock,
                                 padding_mode=padding_mode,
                                 weight_norm_type=weight_norm_type,
                                 activation_norm_type=activation_norm_type,
                                 activation_norm_params=activation_norm_params,
                                 nonlinearity='relu', order='CNACN')
        # Get the number of input label channels (segmentation classes).
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        self.concat_features = False
        # Check whether the label input contains a specific type of data
        # (e.g., instance_maps).
        self.contain_instance_map = False
        if data_cfg.input_labels[-1] == 'instance_maps':
            self.contain_instance_map = True
        # The feature encoder is only useful when the instance map is provided.
        if hasattr(gen_cfg, 'enc') and self.contain_instance_map:
            num_feat_channels = getattr(gen_cfg.enc, 'num_feat_channels', 0)
            if num_feat_channels > 0:
                num_input_channels += num_feat_channels
                self.concat_features = True
                self.encoder = Encoder(gen_cfg.enc, data_cfg)

        # Global generator model.
        global_model = GlobalGenerator(global_gen_cfg, data_cfg,
                                       num_input_channels, padding_mode,
                                       base_conv_block, base_res_block)
        if num_local_enhancers == 0:
            self.global_model = global_model
        else:
            # Drop the final output layer so the global model returns
            # feature maps for the enhancers instead of an image.
            global_model = global_model.model
            global_model = [global_model[i]
                            for i in range(len(global_model) - 1)]
            self.global_model = nn.Sequential(*global_model)

        # Local enhancer model.
        for n in range(num_local_enhancers):
            num_filters = num_filters_global // (2 ** (n + 1))
            output_img = (n == num_local_enhancers - 1)
            setattr(self, 'enhancer_%d' % n,
                    LocalEnhancer(local_gen_cfg, data_cfg,
                                  num_input_channels, num_filters,
                                  padding_mode, base_conv_block,
                                  base_res_block, output_img))

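        # Shared average-pooling downsampler used in forward() to build the
        # label pyramid; count_include_pad=False keeps border averages from
        # being diluted by zero padding.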
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1,
                                       count_include_pad=False)

    def forward(self, data, random_style=False):
        r"""Coarse-to-fine generator forward.

        Args:
            data (dict) : Dictionary of input data.
            random_style (bool): Always set to False for the pix2pixHD model.
        Returns:
            output (dict) : Dictionary of output data.
        """
        label = data['label']

        output = dict()
        if self.concat_features:
            features = self.encoder(data['images'], data['instance_maps'])
            label = torch.cat([label, features], dim=1)
            output['feature_maps'] = features

        # Create input pyramid.
        input_downsampled = [label]
        for i in range(self.num_local_enhancers):
            input_downsampled.append(self.downsample(input_downsampled[-1]))
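        # With k enhancers the pyramid holds [full-res, 1/2-res, ...,
        # 1/2^k-res] labels; the coarsest entry feeds the global model.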

        # Output at coarsest level.
        x = self.global_model(input_downsampled[-1])

        # Coarse-to-fine: build up one layer at a time.
        for n in range(self.num_local_enhancers):
            input_n = input_downsampled[self.num_local_enhancers - n - 1]
            enhancer = getattr(self, 'enhancer_%d' % n)
            x = enhancer(x, input_n)

        output['fake_images'] = x
        return output

    def load_pretrained_network(self, pretrained_dict):
        r"""Load a pretrained network."""
        model_dict = self.state_dict()
        print('Pretrained network has fewer layers; the following are '
              'not initialized:')

        not_initialized = set()
        for k, v in model_dict.items():
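            # The reference pix2pixHD checkpoints appear to have been saved
            # from a DataParallel-wrapped model whose global generator kept
            # its nn.Sequential under `.model`, hence the 'module.' prefix
            # and the 'global_model.' -> 'global_model.model.' remapping.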
            kp = 'module.' + k.replace('global_model.', 'global_model.model.')
            if kp in pretrained_dict and v.size() == pretrained_dict[kp].size():
                model_dict[k] = pretrained_dict[kp]
            else:
                not_initialized.add('.'.join(k.split('.')[:2]))
        print(sorted(not_initialized))
        self.load_state_dict(model_dict)

    def inference(self, data, **kwargs):
        r"""Generator inference.

        Args:
            data (dict) : Dictionary of input data.
        Returns:
            fake_images (tensor): Output fake images.
            file_names (str): Data file name.
        """
        output = self.forward(data, **kwargs)
        return output['fake_images'], data['key']['seg_maps'][0]


class LocalEnhancer(nn.Module):
    r"""Local enhancer constructor. These are sub-networks that are useful
    when aiming to produce high-resolution outputs.

    Args:
        gen_cfg (obj): Local enhancer definition part of the yaml config
            file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of segmentation labels.
        num_filters (int): Number of filters for the first layer.
        padding_mode (str): zero | reflect | ...
        base_conv_block (obj): Conv block with preset attributes.
        base_res_block (obj): Residual block with preset attributes.
        output_img (bool): Whether to output an RGB image (True) or a
            feature map (False).
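
    Example::

        # Illustrative shapes only (sizes hypothetical). The downsampling
        # branch halves ``input_fine`` so it matches the coarse features
        # before the element-wise addition in forward():
        #   output_coarse: (N, 2 * num_filters, H / 2, W / 2)
        #   input_fine:    (N, num_input_channels, H, W)
        #   output:        (N, num_filters, H, W) feature map, or an RGB
        #                  image when ``output_img`` is True.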
    """

    def __init__(self, gen_cfg, data_cfg, num_input_channels, num_filters,
                 padding_mode, base_conv_block, base_res_block,
                 output_img=False):
        super(LocalEnhancer, self).__init__()
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 3)
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        # Downsample.
        model_downsample = \
            [base_conv_block(num_input_channels, num_filters, 7, padding=3),
             base_conv_block(num_filters, num_filters * 2, 3, stride=2,
                             padding=1)]
        # Residual blocks.
        model_upsample = []
        for i in range(num_res_blocks):
            model_upsample += [base_res_block(num_filters * 2, num_filters * 2,
                                              3, padding=1)]
        # Upsample.
        model_upsample += \
            [NearestUpsample(scale_factor=2),
             base_conv_block(num_filters * 2, num_filters, 3, padding=1)]

        # Final convolution.
        if output_img:
            model_upsample += [Conv2dBlock(num_filters, num_img_channels, 7,
                                           padding=3, padding_mode=padding_mode,
                                           nonlinearity='tanh')]

        self.model_downsample = nn.Sequential(*model_downsample)
        self.model_upsample = nn.Sequential(*model_upsample)

    def forward(self, output_coarse, input_fine):
        r"""Local enhancer forward.

        Args:
            output_coarse (4D tensor) : Coarse output from previous layer.
            input_fine (4D tensor) : Fine input from current layer.
        Returns:
            output (4D tensor) : Refined output.
        """
        output = self.model_upsample(
            self.model_downsample(input_fine) + output_coarse)
        return output


class GlobalGenerator(nn.Module):
    r"""Coarse generator constructor. This is the main generator in the
    pix2pixHD architecture.

    Args:
        gen_cfg (obj): Generator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
        num_input_channels (int): Number of segmentation labels.
        padding_mode (str): zero | reflect | ...
        base_conv_block (obj): Conv block with preset attributes.
        base_res_block (obj): Residual block with preset attributes.
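
    Example::

        # Channel / resolution progression with the defaults read below
        # (num_filters=64, num_downsamples=4, num_res_blocks=9), shown for
        # illustration only:
        #   labels -> 64 @ H -> 128 @ H/2 -> 256 @ H/4 -> 512 @ H/8
        #   -> 1024 @ H/16 -> 9 residual blocks @ H/16
        #   -> mirrored nearest-neighbor upsampling -> RGB image @ H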
    """

    def __init__(self, gen_cfg, data_cfg, num_input_channels, padding_mode,
                 base_conv_block, base_res_block):
        super(GlobalGenerator, self).__init__()
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        num_filters = getattr(gen_cfg, 'num_filters', 64)
        num_downsamples = getattr(gen_cfg, 'num_downsamples', 4)
        num_res_blocks = getattr(gen_cfg, 'num_res_blocks', 9)
        # First layer.
        model = [base_conv_block(num_input_channels, num_filters,
                                 kernel_size=7, padding=3)]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2 ** i)
            model += [base_conv_block(ch, ch * 2, 3, padding=1, stride=2)]
        # ResNet blocks.
        ch = num_filters * (2 ** num_downsamples)
        for i in range(num_res_blocks):
            model += [base_res_block(ch, ch, 3, padding=1)]
        # Upsample.
        num_upsamples = num_downsamples
        for i in reversed(range(num_upsamples)):
            ch = num_filters * (2 ** i)
            model += \
                [NearestUpsample(scale_factor=2),
                 base_conv_block(ch * 2, ch, 3, padding=1)]
        model += [Conv2dBlock(num_filters, num_img_channels, 7, padding=3,
                              padding_mode=padding_mode, nonlinearity='tanh')]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        r"""Coarse-to-fine generator forward.

        Args:
            input (4D tensor) : Input semantic representations.
        Returns:
            output (4D tensor) : Synthesized image by generator.
        """
        return self.model(input)


class Encoder(nn.Module):
    r"""Encoder for getting region-wise features for style control.

    Args:
        enc_cfg (obj): Encoder definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
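
    Example::

        # A usage sketch; ``cfg`` and its field names are hypothetical and
        # mirror the getattr defaults read in __init__ below.
        enc = Encoder(cfg.gen.enc, cfg.data)
        feats = enc(images, instance_maps)  # (N, num_feat_channels, H, W)
        # ``feats`` is constant within each instance region and gets
        # concatenated to the label map by the parent Generator.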
    """

    def __init__(self, enc_cfg, data_cfg):
        super(Encoder, self).__init__()
        label_nc = get_paired_input_label_channel_number(data_cfg)
        feat_nc = enc_cfg.num_feat_channels
        n_clusters = getattr(enc_cfg, 'num_clusters', 10)
        for i in range(label_nc):
            dummy_arr = np.zeros((n_clusters, feat_nc), dtype=np.float32)
            self.register_buffer('cluster_%d' % i,
                                 torch.tensor(dummy_arr, dtype=torch.float32))
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_feat_channels = feat_nc
        num_filters = getattr(enc_cfg, 'num_filters', 64)
        num_downsamples = getattr(enc_cfg, 'num_downsamples', 4)
        weight_norm_type = getattr(enc_cfg, 'weight_norm_type', 'none')
        activation_norm_type = getattr(enc_cfg, 'activation_norm_type',
                                       'instance')
        padding_mode = getattr(enc_cfg, 'padding_mode', 'reflect')
        base_conv_block = partial(Conv2dBlock,
                                  padding_mode=padding_mode,
                                  weight_norm_type=weight_norm_type,
                                  activation_norm_type=activation_norm_type,
                                  nonlinearity='relu')
        model = [base_conv_block(num_img_channels, num_filters, 7, padding=3)]
        # Downsample.
        for i in range(num_downsamples):
            ch = num_filters * (2**i)
            model += [base_conv_block(ch, ch * 2, 3, stride=2, padding=1)]
        # Upsample.
        for i in reversed(range(num_downsamples)):
            ch = num_filters * (2 ** i)
            model += [NearestUpsample(scale_factor=2),
                      base_conv_block(ch * 2, ch, 3, padding=1)]

        model += [Conv2dBlock(num_filters, self.num_feat_channels, 7,
                              padding=3, padding_mode=padding_mode,
                              nonlinearity='tanh')]
        self.model = nn.Sequential(*model)

    def forward(self, input, instance_map):
        r"""Extracting region-wise features

        Args:
            input (4D tensor): Real RGB images.
            instance_map (4D tensor): Instance label mask.
        Returns:
            outputs_mean (4D tensor): Instance-wise average-pooled
                feature maps.
        """
        outputs = self.model(input)
        # Instance-wise average pooling.
        outputs_mean = torch.zeros_like(outputs)
        # Find all the unique labels in this batch.
        inst_list = np.unique(instance_map.cpu().numpy().astype(int))
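        # For every (instance id, sample) pair, average the encoder output
        # over that instance's pixels and broadcast the mean back, yielding
        # one feature vector per object region.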
        for i in inst_list:
            for b in range(input.size(0)):
                # Find the pixels in this instance map that have this
                # instance label.
                indices = (instance_map[b:b+1] == int(i)).nonzero(
                    as_tuple=False)  # n x 4
                # Scan through the feature channels.
                for j in range(self.num_feat_channels):
                    output_ins = outputs[indices[:, 0] + b, indices[:, 1] + j,
                                         indices[:, 2], indices[:, 3]]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[indices[:, 0] + b, indices[:, 1] + j,
                                 indices[:, 2], indices[:, 3]] = mean_feat
        return outputs_mean