# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import json
import os
import time

import torch
import torchvision
import wandb
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from imaginaire.utils.distributed import is_master, master_only
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.io import save_pilimage_in_jpeg
from imaginaire.utils.meters import Meter
from imaginaire.utils.misc import to_cuda, to_device, requires_grad, to_channels_last
from imaginaire.utils.model_average import (calibrate_batch_norm_momentum,
                                            reset_batch_norm)
from imaginaire.utils.visualization import tensor2pilimage


class BaseTrainer(object):
    r"""Base trainer. We expect that all trainers inherit this class.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
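
    A typical driver loop (illustrative sketch only; the actual training
    scripts in this repository may differ in details)::

        current_iteration = 0
        for epoch in range(start_epoch, max_epoch):
            trainer.start_of_epoch(epoch)
            for data in train_data_loader:
                data = trainer.start_of_iteration(data, current_iteration)
                trainer.dis_update(data)
                trainer.gen_update(data)
                current_iteration += 1
                trainer.end_of_iteration(data, epoch, current_iteration)
            trainer.end_of_epoch(data, epoch, current_iteration)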
    """

    def __init__(self,
                 cfg,
                 net_G,
                 net_D,
                 opt_G,
                 opt_D,
                 sch_G,
                 sch_D,
                 train_data_loader,
                 val_data_loader):
        super(BaseTrainer, self).__init__()
        print('Setup trainer.')

        # Initialize models and data loaders.
        self.cfg = cfg
        self.net_G = net_G
        if cfg.trainer.model_average_config.enabled:
            # Two wrappers (DDP + model average).
            self.net_G_module = self.net_G.module.module
        else:
            # One wrapper (DDP)
            self.net_G_module = self.net_G.module
        self.val_data_loader = val_data_loader
        self.is_inference = train_data_loader is None
        self.net_D = net_D
        self.opt_G = opt_G
        self.opt_D = opt_D
        self.sch_G = sch_G
        self.sch_D = sch_D
        self.train_data_loader = train_data_loader
        if self.cfg.trainer.channels_last:
            self.net_G = self.net_G.to(memory_format=torch.channels_last)
            self.net_D = self.net_D.to(memory_format=torch.channels_last)

        # Initialize amp.
        if self.cfg.trainer.amp_config.enabled:
            print("Using automatic mixed precision training.")
        self.scaler_G = GradScaler(**vars(self.cfg.trainer.amp_config))
        self.scaler_D = GradScaler(**vars(self.cfg.trainer.amp_config))
        # In order to check whether the discriminator/generator has
        # skipped the last parameter update due to gradient overflow.
        self.last_step_count_G = 0
        self.last_step_count_D = 0
        self.skipped_G = False
        self.skipped_D = False

        # Initialize data augmentation policy.
        self.aug_policy = cfg.trainer.aug_policy
        print("Augmentation policy: {}".format(self.aug_policy))

        # Initialize loss functions.
        # All loss names have weights. Some have criterion modules.
        # Mapping from loss names to criterion modules.
        self.criteria = torch.nn.ModuleDict()
        # Mapping from loss names to loss weights.
        self.weights = dict()
        self.losses = dict(gen_update=dict(), dis_update=dict())
        self.gen_losses = self.losses['gen_update']
        self.dis_losses = self.losses['dis_update']
        self._init_loss(cfg)
        for loss_name, loss_weight in self.weights.items():
            print("Loss {:<20} Weight {}".format(loss_name, loss_weight))
            if loss_name in self.criteria.keys() and \
                    self.criteria[loss_name] is not None:
                self.criteria[loss_name].to('cuda')

        if self.is_inference:
            # The initialization steps below can be skipped during inference.
            return

        # Initialize logging attributes.
        self.current_iteration = 0
        self.current_epoch = 0
        self.start_iteration_time = None
        self.start_epoch_time = None
        self.elapsed_iteration_time = 0
        self.time_iteration = None
        self.time_epoch = None
        self.best_fid = None
        if self.cfg.speed_benchmark:
            self.accu_gen_forw_iter_time = 0
            self.accu_gen_loss_iter_time = 0
            self.accu_gen_back_iter_time = 0
            self.accu_gen_step_iter_time = 0
            self.accu_gen_avg_iter_time = 0
            self.accu_dis_forw_iter_time = 0
            self.accu_dis_loss_iter_time = 0
            self.accu_dis_back_iter_time = 0
            self.accu_dis_step_iter_time = 0

        # Initialize tensorboard and hparams.
        self._init_tensorboard()
        self._init_hparams()

        # Initialize validation parameters.
        self.val_sample_size = getattr(cfg.trainer, 'val_sample_size', 50000)
        self.kid_num_subsets = getattr(cfg.trainer, 'kid_num_subsets', 10)
        self.kid_subset_size = self.val_sample_size // self.kid_num_subsets
        self.metrics_path = os.path.join(torch.hub.get_dir(), 'metrics')
        self.best_metrics = {}
        self.eval_networks = getattr(cfg.trainer, 'eval_network', ['clean_inception'])
        if self.cfg.metrics_iter is None:
            self.cfg.metrics_iter = self.cfg.snapshot_save_iter
        if self.cfg.metrics_epoch is None:
            self.cfg.metrics_epoch = self.cfg.snapshot_save_epoch

        # AWS credentials.
        if hasattr(cfg, 'aws_credentials_file'):
            with open(cfg.aws_credentials_file) as fin:
                self.credentials = json.load(fin)
        else:
            self.credentials = None

        if 'TORCH_HOME' not in os.environ:
            os.environ['TORCH_HOME'] = os.path.join(
                os.environ['HOME'], ".cache")

    def _init_tensorboard(self):
        r"""Initialize the tensorboard. Different algorithms might require
        different performance metrics. Hence, custom tensorboard
        initialization might be necessary.
        """
        # Logging frequency: self.cfg.logging_iter
        self.meters = {}

        # Logging frequency: self.cfg.snapshot_save_iter
        self.metric_meters = {}

        # Logging frequency: self.cfg.image_display_iter
        self.image_meter = Meter('images', reduce=False)

    def _init_hparams(self):
        r"""Initialize a dictionary of hyperparameters that we want to monitor
        in the HParams dashboard in TensorBoard.
        """
        self.hparam_dict = {}

    def _write_tensorboard(self):
        r"""Write values to tensorboard. By default, we will log the time used
        per iteration, time used per epoch, generator learning rate, and
        discriminator learning rate. We will log all the losses as well as
        custom meters.
        """
        # Logs that are shared by all models.
        self._write_to_meters({'time/iteration': self.time_iteration,
                               'time/epoch': self.time_epoch,
                               'optim/gen_lr': self.sch_G.get_last_lr()[0],
                               'optim/dis_lr': self.sch_D.get_last_lr()[0]},
                              self.meters,
                              reduce=False)
        # Logs for loss values. Different models have different losses.
        self._write_loss_meters()
        # Other custom logs.
        self._write_custom_meters()

    def _write_loss_meters(self):
        r"""Write all loss values to tensorboard."""
        for update, losses in self.losses.items():
            # update is 'gen_update' or 'dis_update'.
            assert update == 'gen_update' or update == 'dis_update'
            for loss_name, loss in losses.items():
                if loss is not None:
                    full_loss_name = update + '/' + loss_name
                    if full_loss_name not in self.meters.keys():
                        # Create a new meter if it doesn't exist.
                        self.meters[full_loss_name] = Meter(
                            full_loss_name, reduce=True)
                    self.meters[full_loss_name].write(loss.item())

    def _write_custom_meters(self):
        r"""Dummy member function to be overloaded by the child class.
        In the child class, you can log any custom values you want to track.
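
        For example (illustrative only), a child class that clips gradients
        could log the generator gradient norm recorded in ``gen_update``::

            self._write_to_meters({'optim/gen_grad_norm': self.gen_grad_norm},
                                  self.meters, reduce=False)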
        """
        pass

    @staticmethod
    def _write_to_meters(data, meters, reduce=True):
        r"""Write values to meters."""
        if reduce or is_master():
            for key, value in data.items():
                if key not in meters:
                    meters[key] = Meter(key, reduce=reduce)
                meters[key].write(value)

    def _flush_meters(self, meters):
        r"""Flush all meters using the current iteration."""
        for meter in meters.values():
            meter.flush(self.current_iteration)

    def _pre_save_checkpoint(self):
        r"""Implement the things you want to do before saving a checkpoint.
        For example, you can compute the K-mean features (pix2pixHD) before
        saving the model weights to a checkpoint.
        """
        pass

    def save_checkpoint(self, current_epoch, current_iteration):
        r"""Save network weights, optimizer parameters, scheduler parameters
        to a checkpoint.
        """
        self._pre_save_checkpoint()
        _save_checkpoint(self.cfg,
                         self.net_G, self.net_D,
                         self.opt_G, self.opt_D,
                         self.sch_G, self.sch_D,
                         current_epoch, current_iteration)

    def load_checkpoint(self, cfg, checkpoint_path, resume=None, load_sch=True):
        r"""Load network weights, optimizer parameters, scheduler parameters
        from a checkpoint.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str): Path to the checkpoint.
            resume (bool or None): If not ``None``, determines whether or not
                to load the optimizers and schedulers in addition to the
                network weights.
            load_sch (bool): Whether to load the scheduler states when
                resuming.
        """
        if os.path.exists(checkpoint_path):
            # If checkpoint_path exists, we will load its weights to
            # initialize our network.
            if resume is None:
                resume = False
        elif os.path.exists(os.path.join(cfg.logdir, 'latest_checkpoint.txt')):
            # This is for resuming the training from the previously saved
            # checkpoint.
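            # The ``latest_checkpoint.txt`` file written by
            # ``_save_checkpoint`` contains a single line of the form
            # 'latest_checkpoint: <checkpoint filename>'.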
            fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
            with open(fn, 'r') as f:
                line = f.read().splitlines()
            checkpoint_path = os.path.join(cfg.logdir, line[0].split(' ')[-1])
            if resume is None:
                resume = True
        else:
            # No checkpoint was found and no checkpoint path was specified.
            # We will train everything from scratch.
            current_epoch = 0
            current_iteration = 0
            print('No checkpoint found.')
            resume = False
            return resume, current_epoch, current_iteration
        # Load checkpoint
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage)
        current_epoch = 0
        current_iteration = 0
        if resume:
            self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume)
            if not self.is_inference:
                self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume)
                if 'opt_G' in checkpoint:
                    current_epoch = checkpoint['current_epoch']
                    current_iteration = checkpoint['current_iteration']
                    self.opt_G.load_state_dict(checkpoint['opt_G'])
                    self.opt_D.load_state_dict(checkpoint['opt_D'])
                    if load_sch:
                        self.sch_G.load_state_dict(checkpoint['sch_G'])
                        self.sch_D.load_state_dict(checkpoint['sch_D'])
                    else:
                        if self.cfg.gen_opt.lr_policy.iteration_mode:
                            self.sch_G.last_epoch = current_iteration
                        else:
                            self.sch_G.last_epoch = current_epoch
                        if self.cfg.dis_opt.lr_policy.iteration_mode:
                            self.sch_D.last_epoch = current_iteration
                        else:
                            self.sch_D.last_epoch = current_epoch
                    print('Load from: {}'.format(checkpoint_path))
                else:
                    print('Load network weights only.')
        else:
            try:
                self.net_G.load_state_dict(checkpoint['net_G'], strict=self.cfg.trainer.strict_resume)
                if 'net_D' in checkpoint:
                    self.net_D.load_state_dict(checkpoint['net_D'], strict=self.cfg.trainer.strict_resume)
            except Exception:
                if self.cfg.trainer.model_average_config.enabled:
                    net_G_module = self.net_G.module.module
                else:
                    net_G_module = self.net_G.module
                if hasattr(net_G_module, 'load_pretrained_network'):
                    net_G_module.load_pretrained_network(self.net_G, checkpoint['net_G'])
                    print('Load generator weights only.')
                else:
                    raise ValueError('Checkpoint cannot be loaded.')

        print('Done with loading the checkpoint.')
        return resume, current_epoch, current_iteration

    def start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch.

        Args:
            current_epoch (int): Current epoch number.
        """
        self._start_of_epoch(current_epoch)
        self.current_epoch = current_epoch
        self.start_epoch_time = time.time()

    def start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        """
        data = self._start_of_iteration(data, current_iteration)
        data = to_cuda(data)
        if self.cfg.trainer.channels_last:
            data = to_channels_last(data)
        self.current_iteration = current_iteration
        if not self.is_inference:
            self.net_D.train()
        self.net_G.train()
        # torch.cuda.synchronize()
        self.start_iteration_time = time.time()
        return data

    def end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Things to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current epoch number.
            current_iteration (int): Current iteration number.
        """
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        # Update the learning rate policy for the generator if operating in the
        # iteration mode.
        if self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        # Update the learning rate policy for the discriminator if operating in
        # the iteration mode.
        if self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()

        # Accumulate time
        # torch.cuda.synchronize()
        self.elapsed_iteration_time += time.time() - self.start_iteration_time
        # Logging.
        if current_iteration % self.cfg.logging_iter == 0:
            ave_t = self.elapsed_iteration_time / self.cfg.logging_iter
            self.time_iteration = ave_t
            print('Iteration: {}, average iter time: '
                  '{:6f}.'.format(current_iteration, ave_t))
            self.elapsed_iteration_time = 0

            if self.cfg.speed_benchmark:
                # Below code block only needed when analyzing computation
                # bottleneck.
                print('\tGenerator FWD time {:6f}'.format(
                    self.accu_gen_forw_iter_time / self.cfg.logging_iter))
                print('\tGenerator LOS time {:6f}'.format(
                    self.accu_gen_loss_iter_time / self.cfg.logging_iter))
                print('\tGenerator BCK time {:6f}'.format(
                    self.accu_gen_back_iter_time / self.cfg.logging_iter))
                print('\tGenerator STP time {:6f}'.format(
                    self.accu_gen_step_iter_time / self.cfg.logging_iter))
                print('\tGenerator AVG time {:6f}'.format(
                    self.accu_gen_avg_iter_time / self.cfg.logging_iter))

                print('\tDiscriminator FWD time {:6f}'.format(
                    self.accu_dis_forw_iter_time / self.cfg.logging_iter))
                print('\tDiscriminator LOS time {:6f}'.format(
                    self.accu_dis_loss_iter_time / self.cfg.logging_iter))
                print('\tDiscriminator BCK time {:6f}'.format(
                    self.accu_dis_back_iter_time / self.cfg.logging_iter))
                print('\tDiscriminator STP time {:6f}'.format(
                    self.accu_dis_step_iter_time / self.cfg.logging_iter))

                print('{:6f}'.format(ave_t))

                self.accu_gen_forw_iter_time = 0
                self.accu_gen_loss_iter_time = 0
                self.accu_gen_back_iter_time = 0
                self.accu_gen_step_iter_time = 0
                self.accu_gen_avg_iter_time = 0
                self.accu_dis_forw_iter_time = 0
                self.accu_dis_loss_iter_time = 0
                self.accu_dis_back_iter_time = 0
                self.accu_dis_step_iter_time = 0

        self._end_of_iteration(data, current_epoch, current_iteration)

        # Save everything to the checkpoint.
        if current_iteration % self.cfg.snapshot_save_iter == 0:
            if current_iteration >= self.cfg.snapshot_save_start_iter:
                self.save_checkpoint(current_epoch, current_iteration)

        # Compute metrics.
        if current_iteration % self.cfg.metrics_iter == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.write_metrics()

        # Compute image to be saved.
        elif current_iteration % self.cfg.image_save_iter == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
        elif current_iteration % self.cfg.image_display_iter == 0:
            image_path = os.path.join(self.cfg.logdir, 'images', 'current.jpg')
            self.save_image(image_path, data)

        # Logging.
        self._write_tensorboard()
        if current_iteration % self.cfg.logging_iter == 0:
            # Write all logs to tensorboard.
            self._flush_meters(self.meters)

        # Make sure all processes finish the iteration before moving on.
        import torch.distributed as dist
        if dist.is_initialized():
            dist.barrier()

    def end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Things to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current epoch number.
            current_iteration (int): Current iteration number.
        """
        # Update the learning rate policy for the generator if operating in the
        # epoch mode.
        self.current_iteration = current_iteration
        self.current_epoch = current_epoch
        if not self.cfg.gen_opt.lr_policy.iteration_mode:
            self.sch_G.step()
        # Update the learning rate policy for the discriminator if operating
        # in the epoch mode.
        if not self.cfg.dis_opt.lr_policy.iteration_mode:
            self.sch_D.step()
        elapsed_epoch_time = time.time() - self.start_epoch_time
        # Logging.
        print('Epoch: {}, total time: {:6f}.'.format(current_epoch,
                                                     elapsed_epoch_time))
        self.time_epoch = elapsed_epoch_time
        self._end_of_epoch(data, current_epoch, current_iteration)

        # Save everything to the checkpoint.
        if current_epoch % self.cfg.snapshot_save_epoch == 0:
            if current_epoch >= self.cfg.snapshot_save_start_epoch:
                self.save_checkpoint(current_epoch, current_iteration)

        # Compute metrics.
        if current_epoch % self.cfg.metrics_epoch == 0:
            self.save_image(self._get_save_path('images', 'jpg'), data)
            self.write_metrics()

    def pre_process(self, data):
        r"""Custom data pre-processing function. Utilize this function if you
        need to preprocess your data before sending it to the generator and
        discriminator.

        Args:
            data (dict): Data used for the current iteration.
        Returns:
            (dict): The (optionally pre-processed) data. The base
                implementation returns the data unchanged.
        """
        return data

    def recalculate_batch_norm_statistics(self, data_loader, averaged=True):
        r"""Update the statistics in the moving average model.

        Args:
            data_loader (torch.utils.data.DataLoader): Data loader for
                estimating the statistics.
            averaged (bool): If ``True``, recalculate the batch norm
                statistics of the moving-average (EMA) model; otherwise,
                recalculate those of the regular model.
        """
        if not self.cfg.trainer.model_average_config.enabled:
            return
        if averaged:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G_module
        model_average_iteration = \
            self.cfg.trainer.model_average_config.num_batch_norm_estimation_iterations
        if model_average_iteration == 0:
            return
        with torch.no_grad():
            # Accumulate batch norm statistics.
            net_G.train()
            # Reset running stats.
            net_G.apply(reset_batch_norm)
            for cal_it, cal_data in enumerate(data_loader):
                if cal_it >= model_average_iteration:
                    print('Done with {} iterations of updating batch norm '
                          'statistics'.format(model_average_iteration))
                    break
                cal_data = to_device(cal_data, 'cuda')
                cal_data = self.pre_process(cal_data)
                # Averaging over all batches
                net_G.apply(calibrate_batch_norm_momentum)
                net_G(cal_data)

    def save_image(self, path, data):
        r"""Compute visualization images and save them to the disk.

        Args:
            path (str): Location of the file.
            data (dict): Data used for the current iteration.
        """
        self.net_G.eval()
        vis_images = self._get_visualizations(data)
        if is_master() and vis_images is not None:
            vis_images = torch.cat(
                [img for img in vis_images if img is not None], dim=3).float()
            vis_images = (vis_images + 1) / 2
            print('Save output images to {}'.format(path))
            vis_images.clamp_(0, 1)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            image_grid = torchvision.utils.make_grid(
                vis_images, nrow=1, padding=0, normalize=False)
            if self.cfg.trainer.image_to_tensorboard:
                self.image_meter.write_image(image_grid, self.current_iteration)
            torchvision.utils.save_image(image_grid, path, nrow=1)
            wandb.log({os.path.splitext(os.path.basename(path))[0]: [wandb.Image(path)]})

    def write_metrics(self):
        r"""Write metrics to the tensorboard."""
        cur_fid = self._compute_fid()
        if cur_fid is not None:
            if self.best_fid is not None:
                self.best_fid = min(self.best_fid, cur_fid)
            else:
                self.best_fid = cur_fid
            metric_dict = {'FID': cur_fid, 'best_FID': self.best_fid}
            self._write_to_meters(metric_dict, self.metric_meters, reduce=False)
            self._flush_meters(self.metric_meters)

    def _get_save_path(self, subdir, ext):
        r"""Get the image save path.

        Args:
            subdir (str): Sub-directory under the main directory for saving
                the outputs.
            ext (str): Filename extension for the image (e.g., jpg, png, ...).
        Returns:
            (str): image filename to be used to save the visualization results.
        """
        subdir_path = os.path.join(self.cfg.logdir, subdir)
        if not os.path.exists(subdir_path):
            os.makedirs(subdir_path, exist_ok=True)
        return os.path.join(
            subdir_path, 'epoch_{:05}_iteration_{:09}.{}'.format(
                self.current_epoch, self.current_iteration, ext))

    def _get_outputs(self, net_D_output, real=True):
        r"""Return output values. Note that when the gan mode is relativistic.
        It will do the difference before returning.

        Args:
           net_D_output (dict):
               real_outputs (tensor): Real output values.
               fake_outputs (tensor): Fake output values.
           real (bool): Return real or fake.
        """

        def _get_difference(a, b):
            r"""Get difference between two lists of tensors or two tensors.

            Args:
                a: list of tensors or tensor
                b: list of tensors or tensor
            """
            out = list()
            for x, y in zip(a, b):
                if isinstance(x, list):
                    res = _get_difference(x, y)
                else:
                    res = x - y
                out.append(res)
            return out

        if real:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['real_outputs'], net_D_output['fake_outputs'])
            else:
                return net_D_output['real_outputs']
        else:
            if self.cfg.trainer.gan_relativistic:
                return _get_difference(net_D_output['fake_outputs'], net_D_output['real_outputs'])
            else:
                return net_D_output['fake_outputs']

    def _start_of_epoch(self, current_epoch):
        r"""Operations to do before starting an epoch.

        Args:
            current_epoch (int): Current epoch number.
        """
        pass

    def _start_of_iteration(self, data, current_iteration):
        r"""Operations to do before starting an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        Returns:
            (dict): Data used for the current iteration. They might be
                processed by the custom _start_of_iteration function.
        """
        return data

    def _end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Operations to do after an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current epoch number.
            current_iteration (int): Current iteration number.
        """
        pass

    def _end_of_epoch(self, data, current_epoch, current_iteration):
        r"""Operations to do after an epoch.

        Args:
            data (dict): Data used for the current iteration.
            current_epoch (int): Current epoch number.
            current_iteration (int): Current iteration number.
        """
        pass

    def _get_visualizations(self, data):
        r"""Compute visualization outputs.

        Args:
            data (dict): Data used for the current iteration.
        """
        return None

    def _compute_fid(self):
        r"""FID computation function to be overloaded."""
        return None

    def _init_loss(self, cfg):
        r"""Every trainer should implement its own init loss function."""
        raise NotImplementedError

    def gen_update(self, data):
        r"""Update the generator.

        Args:
            data (dict): Data used for the current iteration.
        """
        update_finished = False
        while not update_finished:
            # Set requires_grad flags.
            requires_grad(self.net_G_module, True)
            requires_grad(self.net_D, False)

            # Compute the loss.
            self._time_before_forward()
            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
                total_loss = self.gen_forward(data)
            if total_loss is None:
                return

            # Zero-grad and backpropagate the loss.
            self.opt_G.zero_grad(set_to_none=True)
            self._time_before_backward()
            self.scaler_G.scale(total_loss).backward()

            # Optionally clip gradient norm.
            if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
                self.scaler_G.unscale_(self.opt_G)
                total_norm = torch.nn.utils.clip_grad_norm_(
                    self.net_G_module.parameters(),
                    self.cfg.gen_opt.clip_grad_norm
                )
                self.gen_grad_norm = total_norm
                if torch.isfinite(total_norm) and \
                        total_norm > self.cfg.gen_opt.clip_grad_norm:
                    # print(f"Gradient norm of the generator ({total_norm}) "
                    #       f"too large.")
                    if getattr(self.cfg.gen_opt, 'skip_grad', False):
                        print(f"Skip gradient update.")
                        self.opt_G.zero_grad(set_to_none=True)
                        self.scaler_G.step(self.opt_G)
                        self.scaler_G.update()
                        break
                    # else:
                    #     print(f"Clip gradient norm to "
                    #           f"{self.cfg.gen_opt.clip_grad_norm}.")

            # Perform an optimizer step.
            self._time_before_step()
            self.scaler_G.step(self.opt_G)
            self.scaler_G.update()
            # Whether the step above was skipped.
            if self.last_step_count_G == self.opt_G._step_count:
                print("Generator overflowed!")
                if not torch.isfinite(total_loss):
                    print("Generator loss is not finite. Skip this iteration!")
                    update_finished = True
            else:
                self.last_step_count_G = self.opt_G._step_count
                update_finished = True

        self._extra_gen_step(data)

        # Update model average.
        self._time_before_model_avg()
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.update_average()

        self._detach_losses()
        self._time_before_leave_gen()

    def gen_forward(self, data):
        r"""Every trainer should implement its own generator forward."""
        raise NotImplementedError

    def _extra_gen_step(self, data):
        pass

    def dis_update(self, data):
        r"""Update the discriminator.

        Args:
            data (dict): Data used for the current iteration.
        """
        update_finished = False
        while not update_finished:
            # Set requires_grad flags.
            requires_grad(self.net_G_module, False)
            requires_grad(self.net_D, True)

            # Compute the loss.
            self._time_before_forward()
            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
                total_loss = self.dis_forward(data)
            if total_loss is None:
                return

            # Zero-grad and backpropagate the loss.
            self.opt_D.zero_grad(set_to_none=True)
            self._time_before_backward()
            self.scaler_D.scale(total_loss).backward()

            # Optionally clip gradient norm.
            if hasattr(self.cfg.dis_opt, 'clip_grad_norm'):
                self.scaler_D.unscale_(self.opt_D)
                total_norm = torch.nn.utils.clip_grad_norm_(
                    self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm
                )
                self.dis_grad_norm = total_norm
                if torch.isfinite(total_norm) and \
                        total_norm > self.cfg.dis_opt.clip_grad_norm:
                    print(f"Gradient norm of the discriminator ({total_norm}) "
                          f"too large.")
                    if getattr(self.cfg.dis_opt, 'skip_grad', False):
                        print(f"Skip gradient update.")
                        self.opt_D.zero_grad(set_to_none=True)
                        self.scaler_D.step(self.opt_D)
                        self.scaler_D.update()
                        continue
                    else:
                        print(f"Clip gradient norm to "
                              f"{self.cfg.dis_opt.clip_grad_norm}.")

            # Perform an optimizer step.
            self._time_before_step()
            self.scaler_D.step(self.opt_D)
            self.scaler_D.update()
            # Whether the step above was skipped.
            if self.last_step_count_D == self.opt_D._step_count:
                print("Discriminator overflowed!")
                if not torch.isfinite(total_loss):
                    print("Discriminator loss is not finite. "
                          "Skip this iteration!")
                    update_finished = True
            else:
                self.last_step_count_D = self.opt_D._step_count
                update_finished = True

        self._extra_dis_step(data)

        self._detach_losses()
        self._time_before_leave_dis()

    def dis_forward(self, data):
        r"""Every trainer should implement its own discriminator forward."""
        raise NotImplementedError

    def _extra_dis_step(self, data):
        pass

    def test(self, data_loader, output_dir, inference_args):
        r"""Compute results images for a batch of input data and save the
        results in the specified folder.

        Args:
            data_loader (torch.utils.data.DataLoader): PyTorch dataloader.
            output_dir (str): Target location for saving the output images.
            inference_args (obj): Options forwarded to the generator's
                ``inference`` method.
        """
        if self.cfg.trainer.model_average_config.enabled:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G.module
        net_G.eval()

        print('# of samples %d' % len(data_loader))
        for it, data in enumerate(tqdm(data_loader)):
            data = self.start_of_iteration(data, current_iteration=-1)
            with torch.no_grad():
                output_images, file_names = \
                    net_G.inference(data, **vars(inference_args))
            for output_image, file_name in zip(output_images, file_names):
                fullname = os.path.join(output_dir, file_name + '.jpg')
                output_image = tensor2pilimage(output_image.clamp_(-1, 1),
                                               minus1to1_normalized=True)
                save_pilimage_in_jpeg(fullname, output_image)

    def _get_total_loss(self, gen_forward):
        r"""Return the total loss to be backpropagated.
        Args:
            gen_forward (bool): If ``True``, backpropagates the generator loss,
                otherwise the discriminator loss.
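
        For example (with purely illustrative names and weights), if
        ``self.weights == {'gan': 1.0, 'feature_matching': 10.0}`` and both
        entries are present in ``self.gen_losses``, the generator total loss
        is ``1.0 * gen_losses['gan'] + 10.0 * gen_losses['feature_matching']``.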
        """
        losses = self.gen_losses if gen_forward else self.dis_losses
        total_loss = torch.tensor(0., device=torch.device('cuda'))
        # Iterates over all possible losses.
        for loss_name in self.weights:
            # If it is for the current model (gen/dis).
            if loss_name in losses:
                # Multiply it with the corresponding weight
                # and add it to the total loss.
                total_loss += losses[loss_name] * self.weights[loss_name]
        losses['total'] = total_loss  # logging purpose
        return total_loss

    def _detach_losses(self):
        r"""Detach all logging variables to prevent potential memory leak."""
        for loss_name in self.gen_losses:
            self.gen_losses[loss_name] = self.gen_losses[loss_name].detach()
        for loss_name in self.dis_losses:
            self.dis_losses[loss_name] = self.dis_losses[loss_name].detach()

    def _time_before_forward(self):
        r"""
        Record time before applying forward.
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.forw_time = time.time()

    def _time_before_loss(self):
        r"""
        Record time before computing loss.
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.loss_time = time.time()

    def _time_before_backward(self):
        r"""
        Record time before applying backward.
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.back_time = time.time()

    def _time_before_step(self):
        r"""
        Record time before updating the weights
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.step_time = time.time()

    def _time_before_model_avg(self):
        r"""
        Record time before applying model average.
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            self.avg_time = time.time()

    def _time_before_leave_gen(self):
        r"""
        Record forward, backward, loss, and model average time for the
        generator update.
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_gen_forw_iter_time += self.loss_time - self.forw_time
            self.accu_gen_loss_iter_time += self.back_time - self.loss_time
            self.accu_gen_back_iter_time += self.step_time - self.back_time
            self.accu_gen_step_iter_time += self.avg_time - self.step_time
            self.accu_gen_avg_iter_time += end_time - self.avg_time

    def _time_before_leave_dis(self):
        r"""
        Record forward, backward, loss time for the discriminator update.
        """
        if self.cfg.speed_benchmark:
            torch.cuda.synchronize()
            end_time = time.time()
            self.accu_dis_forw_iter_time += self.loss_time - self.forw_time
            self.accu_dis_loss_iter_time += self.back_time - self.loss_time
            self.accu_dis_back_iter_time += self.step_time - self.back_time
            self.accu_dis_step_iter_time += end_time - self.step_time


@master_only
def _save_checkpoint(cfg,
                     net_G, net_D,
                     opt_G, opt_D,
                     sch_G, sch_D,
                     current_epoch, current_iteration):
    r"""Save network weights, optimizer parameters, scheduler parameters
    in the checkpoint.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        current_epoch (int): Current epoch.
        current_iteration (int): Current iteration.
    """
    latest_checkpoint_path = 'epoch_{:05}_iteration_{:09}_checkpoint.pt'.format(
        current_epoch, current_iteration)
    save_path = os.path.join(cfg.logdir, latest_checkpoint_path)
    torch.save(
        {
            'net_G': net_G.state_dict(),
            'net_D': net_D.state_dict(),
            'opt_G': opt_G.state_dict(),
            'opt_D': opt_D.state_dict(),
            'sch_G': sch_G.state_dict(),
            'sch_D': sch_D.state_dict(),
            'current_epoch': current_epoch,
            'current_iteration': current_iteration,
        },
        save_path,
    )
    fn = os.path.join(cfg.logdir, 'latest_checkpoint.txt')
    with open(fn, 'wt') as f:
        f.write('latest_checkpoint: %s' % latest_checkpoint_path)
    print('Save checkpoint to {}'.format(save_path))
    return save_path
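

# ---------------------------------------------------------------------------
# Illustrative sketch only (never instantiated by the library): a minimal
# subclass showing the hooks a concrete trainer is expected to implement.
# The class name, the 'gan' loss key, the unit loss weight, the BCE-based
# GAN loss, and the assumption that the discriminator returns single tensors
# under 'real_outputs'/'fake_outputs' are all hypothetical placeholders;
# real trainers in this repository define their own losses and forwards.
# ---------------------------------------------------------------------------
class _ExampleTrainer(BaseTrainer):
    def _init_loss(self, cfg):
        # Populate the tables that BaseTrainer.__init__ iterates over.
        self.criteria['gan'] = torch.nn.BCEWithLogitsLoss()
        self.weights['gan'] = 1.0

    def gen_forward(self, data):
        net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)
        fake_outputs = self._get_outputs(net_D_output, real=False)
        # Non-saturating generator loss: push fake outputs towards "real".
        self.gen_losses['gan'] = self.criteria['gan'](
            fake_outputs, torch.ones_like(fake_outputs))
        return self._get_total_loss(gen_forward=True)

    def dis_forward(self, data):
        with torch.no_grad():
            net_G_output = self.net_G(data)
        net_D_output = self.net_D(data, net_G_output)
        real_outputs = self._get_outputs(net_D_output, real=True)
        fake_outputs = self._get_outputs(net_D_output, real=False)
        self.dis_losses['gan'] = \
            self.criteria['gan'](real_outputs,
                                 torch.ones_like(real_outputs)) + \
            self.criteria['gan'](fake_outputs,
                                 torch.zeros_like(fake_outputs))
        return self._get_total_loss(gen_forward=False)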