# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

from torch.cuda.amp import autocast

import imageio
import numpy as np
import torch
from tqdm import tqdm

from imaginaire.evaluation.fid import compute_fid
from imaginaire.losses import (FeatureMatchingLoss, FlowLoss, GANLoss,
                               PerceptualLoss)
from imaginaire.model_utils.fs_vid2vid import (concat_frames, detach,
                                               get_fg_mask,
                                               pre_process_densepose, resample)
from imaginaire.trainers.base import BaseTrainer
from imaginaire.utils.distributed import is_master
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.misc import get_nested_attr, split_labels, to_cuda
from imaginaire.utils.visualization import (tensor2flow, tensor2im,
                                            tensor2label)
from imaginaire.utils.visualization.pose import tensor2pose


class Trainer(BaseTrainer):
    r"""Initialize vid2vid trainer.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)
        # The settings below are for testing. The FID computed during
        # training is only meant to give a quick idea of the performance;
        # it is not equivalent to the final performance evaluation.
        # Here we determine how many videos to evaluate and the length of
        # each video. It is better to keep the number of videos a multiple
        # of 8 so that all the GPUs in a node contribute equally to the
        # evaluation and none of them is idle.
        self.sample_size = (
            getattr(cfg.trainer, 'num_videos_to_test', 64),
            getattr(cfg.trainer, 'num_frames_per_video', 10)
        )
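        # For example, with the defaults above, sample_size == (64, 10):
        # the quick FID check generates 64 validation videos of 10 frames
        # each.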

        self.sequence_length = 1
        if not self.is_inference:
            self.train_dataset = self.train_data_loader.dataset
            self.sequence_length_max = \
                min(getattr(cfg.data.train, 'max_sequence_length', 100),
                    self.train_dataset.sequence_length_max)
        self.Tensor = torch.cuda.FloatTensor
        self.has_fg = getattr(cfg.data, 'has_foreground', False)

        self.net_G_output = self.data_prev = None
        self.net_G_module = self.net_G.module
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G_module = self.net_G_module.module

    def _assign_criteria(self, name, criterion, weight):
        r"""Assign training loss terms.

        Args:
            name (str): Loss name.
            criterion (obj): Loss object.
            weight (float): Loss weight. It should be non-negative.
        """
        self.criteria[name] = criterion
        self.weights[name] = weight

    def _init_loss(self, cfg):
        r"""Initialize training loss terms. In vid2vid, in addition to
        the GAN loss, feature matching loss, and perceptual loss used in
        pix2pixHD, we also add temporal GAN (and feature matching) loss,
        and flow warping loss. Optionally, we can also add an additional
        face discriminator for the face region.

        Args:
            cfg (obj): Global configuration.
        """
        self.criteria = dict()
        self.weights = dict()
        trainer_cfg = cfg.trainer
        loss_weight = cfg.trainer.loss_weight

        # GAN loss and feature matching loss.
        self._assign_criteria('GAN',
                              GANLoss(trainer_cfg.gan_mode),
                              loss_weight.gan)
        self._assign_criteria('FeatureMatching',
                              FeatureMatchingLoss(),
                              loss_weight.feature_matching)

        # Perceptual loss.
        perceptual_loss = cfg.trainer.perceptual_loss
        self._assign_criteria('Perceptual',
                              PerceptualLoss(
                                  network=perceptual_loss.mode,
                                  layers=perceptual_loss.layers,
                                  weights=perceptual_loss.weights,
                                  num_scales=getattr(perceptual_loss,
                                                     'num_scales', 1)),
                              loss_weight.perceptual)

        # L1 loss.
        if getattr(loss_weight, 'L1', 0) > 0:
            self._assign_criteria('L1', torch.nn.L1Loss(), loss_weight.L1)

        # Whether to add an additional discriminator for specific regions.
        self.add_dis_cfg = getattr(self.cfg.dis, 'additional_discriminators',
                                   None)
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                add_dis_cfg = self.add_dis_cfg[name]
                self.weights['GAN_' + name] = add_dis_cfg.loss_weight
                self.weights['FeatureMatching_' + name] = \
                    loss_weight.feature_matching

        # Temporal GAN loss.
        self.num_temporal_scales = get_nested_attr(self.cfg.dis,
                                                   'temporal.num_scales', 0)
        for s in range(self.num_temporal_scales):
            self.weights['GAN_T%d' % s] = loss_weight.temporal_gan
            self.weights['FeatureMatching_T%d' % s] = \
                loss_weight.feature_matching

        # Flow loss. It consists of three parts: the L1 loss compared to the
        # ground-truth flow, the warping loss when the flow is used to warp
        # images, and the loss on the occlusion mask.
        self.use_flow = hasattr(cfg.gen, 'flow')
        if self.use_flow:
            self.criteria['Flow'] = FlowLoss(cfg)
            self.weights['Flow'] = self.weights['Flow_L1'] = \
                self.weights['Flow_Warp'] = \
                self.weights['Flow_Mask'] = loss_weight.flow
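
        # Note: every key that may later appear in self.gen_losses or
        # self.dis_losses (e.g. 'GAN', 'FeatureMatching', 'Perceptual',
        # 'GAN_T0', 'Flow_L1', ...) needs a matching entry in self.weights,
        # since the total loss is computed as a weighted sum over those keys.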

        # Other custom losses.
        self._define_custom_losses()

    def _define_custom_losses(self):
        r"""All other custom losses are defined here."""
        pass

    def _start_of_epoch(self, current_epoch):
        r"""Things to do before an epoch. When current_epoch is smaller than
        $(single_frame_epoch), we only train a single frame and the generator
        is just an image generator. After that, we start temporal training
        and train multiple frames. We double the number of training frames
        every $(num_epochs_temporal_step) epochs.

        Args:
            current_epoch (int): Current epoch number.
        """
        cfg = self.cfg
        # Only generate one frame at the beginning of training.
        if current_epoch < cfg.single_frame_epoch:
            self.train_dataset.sequence_length = 1
        # Then add the temporal network to the generator and train multiple
        # frames.
        elif current_epoch == cfg.single_frame_epoch:
            self.init_temporal_network()

        # Double the length of the training sequence every few epochs.
        temp_epoch = current_epoch - cfg.single_frame_epoch
        if temp_epoch > 0:
            sequence_length = \
                cfg.data.train.initial_sequence_length * \
                (2 ** (temp_epoch // cfg.num_epochs_temporal_step))
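            # Illustrative schedule (hypothetical config values): with
            # single_frame_epoch=10, initial_sequence_length=4 and
            # num_epochs_temporal_step=5, the sequence length is 4 at epoch
            # 10, 8 at epoch 15, 16 at epoch 20, and so on, capped below at
            # sequence_length_max.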
            sequence_length = min(sequence_length, self.sequence_length_max)
            if sequence_length > self.sequence_length:
                self.sequence_length = sequence_length
                self.train_dataset.set_sequence_length(sequence_length)
                print('------- Updating sequence length to %d -------' %
                      sequence_length)

    def init_temporal_network(self):
        r"""Initialize temporal training when beginning to train multiple
        frames. Set the sequence length to $(initial_sequence_length).
        """
        self.tensorboard_init = False
        # Update the training sequence length.
        self.sequence_length = self.cfg.data.train.initial_sequence_length
        if not self.is_inference:
            self.train_dataset.set_sequence_length(self.sequence_length)
            print('------ Now start training %d frames -------' %
                  self.sequence_length)

    def _start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        """
        data = self.pre_process(data)
        return to_cuda(data)

    def pre_process(self, data):
        r"""Do any data pre-processing here.

        Args:
            data (dict): Data used for the current iteration.
        """
        data_cfg = self.cfg.data
        if hasattr(data_cfg, 'for_pose_dataset') and \
                ('pose_maps-densepose' in data_cfg.input_labels):
            pose_cfg = data_cfg.for_pose_dataset
            data['label'] = pre_process_densepose(pose_cfg, data['label'],
                                                  self.is_inference)
        return data

    def post_process(self, data, net_G_output):
        r"""Do any postprocessing of the data / output here.

        Args:
            data (dict): Training data at the current iteration.
            net_G_output (dict): Output of the generator.
        """
        return data, net_G_output

    def gen_update(self, data):
        r"""Update the vid2vid generator. We update in the fashion of
        dis_update (frame 1), gen_update (frame 1),
        dis_update (frame 2), gen_update (frame 2), ... in each iteration.

        Args:
            data (dict): Training data at the current iteration.
        """
        # Whether to reuse the generator output for both gen_update and
        # dis_update. It saves time but consumes a bit more memory.
        reuse_gen_output = getattr(self.cfg.trainer, 'reuse_gen_output', True)

        past_frames = [None, None]
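        # past_frames is carried across time steps and updated by the
        # discriminator forward pass below; it gives the temporal
        # discriminator access to frames from earlier time steps of the
        # current sequence.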
        net_G_output = None
        data_prev = None
        for t in range(self.sequence_length):
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # Discriminator update.
            if reuse_gen_output:
                net_G_output = self.net_G(data_t)
            else:
                with torch.no_grad():
                    net_G_output = self.net_G(data_t)
            data_t, net_G_output = self.post_process(data_t, net_G_output)

            # Compute the losses and update D only if the image was generated
            # by the network being trained (not by a pretrained model).
            if 'fake_images_source' not in net_G_output:
                net_G_output['fake_images_source'] = 'in_training'
            if net_G_output['fake_images_source'] != 'pretrained':
                net_D_output, _ = self.net_D(data_t, detach(net_G_output),
                                             past_frames)
                self.get_dis_losses(net_D_output)

            # Generator update.
            if not reuse_gen_output:
                net_G_output = self.net_G(data_t)
                data_t, net_G_output = self.post_process(data_t, net_G_output)

            # Compute the losses and update G only if the image was generated
            # by the network being trained (not by a pretrained model).
            if 'fake_images_source' not in net_G_output:
                net_G_output['fake_images_source'] = 'in_training'
            if net_G_output['fake_images_source'] != 'pretrained':
                net_D_output, past_frames = \
                    self.net_D(data_t, net_G_output, past_frames)
                self.get_gen_losses(data_t, net_G_output, net_D_output)

        # Update the moving average of the generator weights.
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.update_average()

    def dis_update(self, data):
        r"""The update is already done in gen_update.

        Args:
            data (dict): Training data at the current iteration.
        """
        pass

    def reset(self):
        r"""Reset the trainer (for inference) at the beginning of a sequence.
        """
        # print('Resetting trainer.')
        self.net_G_output = self.data_prev = None
        self.t = 0

        self.test_in_model_average_mode = getattr(
            self, 'test_in_model_average_mode',
            self.cfg.trainer.model_average_config.enabled)
        if self.test_in_model_average_mode:
            net_G_module = self.net_G.module.averaged_model
        else:
            net_G_module = self.net_G.module
        if hasattr(net_G_module, 'reset'):
            net_G_module.reset()

    def create_sequence_output_dir(self, output_dir, key):
        r"""Create output subdir for this sequence.

        Args:
            output_dir (str): Root output dir.
            key (str): LMDB key which contains sequence name and file name.
        Returns:
            output_dir (str): Output subdir for this sequence.
            seq_name (str): Name of this sequence.
        """
        seq_dir = '/'.join(key.split('/')[:-1])
        output_dir = os.path.join(output_dir, seq_dir)
        os.makedirs(output_dir, exist_ok=True)
        seq_name = seq_dir.replace('/', '-')
        return output_dir, seq_name

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on all sequences.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """
        # Go over all sequences.
        loader = test_data_loader
        num_inference_sequences = loader.dataset.num_inference_sequences()
        for sequence_idx in range(num_inference_sequences):
            loader.dataset.set_inference_sequence_idx(sequence_idx)
            print('Seq id: %d, Seq length: %d' %
                  (sequence_idx + 1, len(loader)))
            # Reset model at start of new inference sequence.
            self.reset()
            self.sequence_length = len(loader)

            # Go over all frames of this sequence.
            video = []
            for idx, data in enumerate(tqdm(loader)):
                key = data['key']['images'][0][0]
                filename = key.split('/')[-1]

                # Create output dir for this sequence.
                if idx == 0:
                    output_dir, seq_name = \
                        self.create_sequence_output_dir(root_output_dir, key)
                    video_path = os.path.join(output_dir, '..', seq_name)

                # Get output and save images.
                data['img_name'] = filename
                data = self.start_of_iteration(data, current_iteration=-1)
                output = self.test_single(data, output_dir, inference_args)
                video.append(output)

            # Save output as mp4.
            imageio.mimsave(video_path + '.mp4', video, fps=15)

    def test_single(self, data, output_dir=None, inference_args=None):
        r"""The inference function. If output_dir exists, also save the
        output image.

        Args:
            data (dict): Training data at the current iteration.
            output_dir (str): Save image directory.
            inference_args (obj): Inference args.
        """
        if getattr(inference_args, 'finetune', False):
            if not getattr(self, 'has_finetuned', False):
                self.finetune(data, inference_args)

        net_G = self.net_G
        if self.test_in_model_average_mode:
            net_G = net_G.module.averaged_model
        net_G.eval()
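
        # self.data_prev and self.net_G_output carry the recurrent state
        # across successive test_single() calls; reset() clears them at the
        # start of each new sequence.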
        data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0)
        if self.is_inference or self.sequence_length > 1:
            self.data_prev = data_t

        # Generator forward.
        with torch.no_grad():
            self.net_G_output = net_G(data_t)

        if output_dir is None:
            return self.net_G_output

        save_fake_only = getattr(inference_args, 'save_fake_only', False)
        if save_fake_only:
            image_grid = tensor2im(self.net_G_output['fake_images'])[0]
        else:
            vis_images = self.get_test_output_images(data)
            image_grid = np.hstack([np.vstack(im) for im in
                                    vis_images if im is not None])

        if 'img_name' in data:
            save_name = data['img_name'].split('.')[0] + '.jpg'
        else:
            save_name = '%04d.jpg' % self.t
        output_filename = os.path.join(output_dir, save_name)
        os.makedirs(output_dir, exist_ok=True)
        imageio.imwrite(output_filename, image_grid)
        self.t += 1
        return image_grid

    def get_test_output_images(self, data):
        r"""Get the visualization output of the test function.

        Args:
            data (dict): Training data at the current iteration.
        """
        vis_images = [
            self.visualize_label(data['label'][:, -1]),
            tensor2im(data['images'][:, -1]),
            tensor2im(self.net_G_output['fake_images']),
        ]
        return vis_images

    def gen_frames(self, data, use_model_average=False):
        r"""Generate a sequence of frames given a sequence of data.

        Args:
            data (dict): Training data at the current iteration.
            use_model_average (bool): Whether to use the averaged model
                for generation.
        """
        net_G_output = None  # Previous generator output.
        data_prev = None  # Previous data.
        if use_model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G

        # Iterate through the length of the sequence.
        all_info = {'inputs': [], 'outputs': []}
        for t in range(self.sequence_length):
            # Get the data at the current time frame.
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # Generator forward.
            with torch.no_grad():
                net_G_output = net_G(data_t)

            # Do any postprocessing if necessary.
            data_t, net_G_output = self.post_process(data_t, net_G_output)

            if t == 0:
                # Get the output at the beginning of the sequence for
                # visualization.
                first_net_G_output = net_G_output
            all_info['inputs'].append(data_t)
            all_info['outputs'].append(net_G_output)

        return first_net_G_output, net_G_output, all_info

    def get_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""Compute generator losses.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        update_finished = False
        while not update_finished:
            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
                # Individual frame GAN loss and feature matching loss.
                self.gen_losses['GAN'], self.gen_losses['FeatureMatching'] = \
                    self.compute_gan_losses(net_D_output['indv'],
                                            dis_update=False)

                # Perceptual loss.
                self.gen_losses['Perceptual'] = self.criteria['Perceptual'](
                    net_G_output['fake_images'], data_t['image'])

                # L1 loss.
                if getattr(self.cfg.trainer.loss_weight, 'L1', 0) > 0:
                    self.gen_losses['L1'] = self.criteria['L1'](
                        net_G_output['fake_images'], data_t['image'])

                # Raw (hallucinated) output image losses (GAN and perceptual).
                if 'raw' in net_D_output:
                    raw_GAN_losses = self.compute_gan_losses(
                        net_D_output['raw'], dis_update=False
                    )
                    fg_mask = get_fg_mask(data_t['label'], self.has_fg)
                    raw_perceptual_loss = self.criteria['Perceptual'](
                        net_G_output['fake_raw_images'] * fg_mask,
                        data_t['image'] * fg_mask)
                    self.gen_losses['GAN'] += raw_GAN_losses[0]
                    self.gen_losses['FeatureMatching'] += raw_GAN_losses[1]
                    self.gen_losses['Perceptual'] += raw_perceptual_loss

                # Additional discriminator losses.
                if self.add_dis_cfg is not None:
                    for name in self.add_dis_cfg:
                        (self.gen_losses['GAN_' + name],
                         self.gen_losses['FeatureMatching_' + name]) = \
                            self.compute_gan_losses(net_D_output[name],
                                                    dis_update=False)

                # Flow and mask loss.
                if self.use_flow:
                    (self.gen_losses['Flow_L1'], self.gen_losses['Flow_Warp'],
                     self.gen_losses['Flow_Mask']) = self.criteria['Flow'](
                        data_t, net_G_output, self.current_epoch)

                # Temporal GAN loss and feature matching loss.
                if self.cfg.trainer.loss_weight.temporal_gan > 0:
                    if self.sequence_length > 1:
                        for s in range(self.num_temporal_scales):
                            loss_GAN, loss_FM = self.compute_gan_losses(
                                net_D_output['temporal_%d' % s],
                                dis_update=False
                            )
                            self.gen_losses['GAN_T%d' % s] = loss_GAN
                            self.gen_losses['FeatureMatching_T%d' % s] = \
                                loss_FM

                # Other custom losses.
                self._get_custom_gen_losses(data_t, net_G_output,
                                            net_D_output)

                # Sum all losses together.
                total_loss = self.Tensor(1).fill_(0)
                for key in self.gen_losses:
                    if key != 'total':
                        total_loss += self.gen_losses[key] * self.weights[key]
                self.gen_losses['total'] = total_loss

            # Zero-grad and backpropagate the loss.
            self.opt_G.zero_grad(set_to_none=True)
            self.scaler_G.scale(total_loss).backward()

            # Optionally clip gradient norm.
            if hasattr(self.cfg.gen_opt, 'clip_grad_norm'):
                self.scaler_G.unscale_(self.opt_G)
                total_norm = torch.nn.utils.clip_grad_norm_(
                    self.net_G_module.parameters(),
                    self.cfg.gen_opt.clip_grad_norm
                )
                if torch.isfinite(total_norm) and \
                        total_norm > self.cfg.gen_opt.clip_grad_norm:
                    print(f"Gradient norm of the generator ({total_norm}) "
                          f"too large, clipping it to "
                          f"{self.cfg.gen_opt.clip_grad_norm}.")

            # Perform an optimizer step.
            self.scaler_G.step(self.opt_G)
            self.scaler_G.update()
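            # If AMP detected inf/NaN gradients, scaler_G.step() skipped the
            # optimizer step and _step_count did not advance; in that case
            # the loop below retries the update with the reduced loss scale.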
            # Whether the step above was skipped.
            if self.last_step_count_G == self.opt_G._step_count:
                print("Generator overflowed!")
            else:
                self.last_step_count_G = self.opt_G._step_count
                update_finished = True

    def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""All other custom generator losses go here.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        pass

    def get_dis_losses(self, net_D_output):
        r"""Compute discriminator losses.

        Args:
            net_D_output (dict): Output of the discriminator.
        """
        update_finished = False
        while not update_finished:
            with autocast(enabled=self.cfg.trainer.amp_config.enabled):
                # Individual frame GAN loss.
                self.dis_losses['GAN'] = self.compute_gan_losses(
                    net_D_output['indv'], dis_update=True
                )

                # Raw (hallucinated) output image GAN loss.
                if 'raw' in net_D_output:
                    raw_loss = self.compute_gan_losses(net_D_output['raw'],
                                                       dis_update=True)
                    self.dis_losses['GAN'] += raw_loss

                # Additional GAN loss.
                if self.add_dis_cfg is not None:
                    for name in self.add_dis_cfg:
                        self.dis_losses['GAN_' + name] = \
                            self.compute_gan_losses(net_D_output[name],
                                                    dis_update=True)

                # Temporal GAN loss.
                if self.cfg.trainer.loss_weight.temporal_gan > 0:
                    if self.sequence_length > 1:
                        for s in range(self.num_temporal_scales):
                            self.dis_losses['GAN_T%d' % s] = \
                                self.compute_gan_losses(
                                    net_D_output['temporal_%d' % s],
                                    dis_update=True
                                )

                # Other custom losses.
                self._get_custom_dis_losses(net_D_output)

                # Sum all losses together.
                total_loss = self.Tensor(1).fill_(0)
                for key in self.dis_losses:
                    if key != 'total':
                        total_loss += self.dis_losses[key] * self.weights[key]
                self.dis_losses['total'] = total_loss

            # Zero-grad and backpropagate the loss.
            self.opt_D.zero_grad(set_to_none=True)
            self._time_before_backward()
            self.scaler_D.scale(total_loss).backward()

            # Optionally clip gradient norm.
            if hasattr(self.cfg.dis_opt, 'clip_grad_norm'):
                self.scaler_D.unscale_(self.opt_D)
                total_norm = torch.nn.utils.clip_grad_norm_(
                    self.net_D.parameters(), self.cfg.dis_opt.clip_grad_norm
                )
                if torch.isfinite(total_norm) and \
                        total_norm > self.cfg.dis_opt.clip_grad_norm:
                    print(f"Gradient norm of the discriminator ({total_norm}) "
                          f"too large, clipping it to "
                          f"{self.cfg.dis_opt.clip_grad_norm}.")

            # Perform an optimizer step.
            self._time_before_step()
            self.scaler_D.step(self.opt_D)
            self.scaler_D.update()
            # Whether the step above was skipped.
            if self.last_step_count_D == self.opt_D._step_count:
                print("Discriminator overflowed!")
            else:
                self.last_step_count_D = self.opt_D._step_count
                update_finished = True

    def _get_custom_dis_losses(self, net_D_output):
        r"""All other custom discriminator losses go here.

        Args:
            net_D_output (dict): Output of the discriminator.
        """
        pass

    def compute_gan_losses(self, net_D_output, dis_update):
        r"""Compute GAN loss and feature matching loss.

        Args:
            net_D_output (dict): Output of the discriminator.
            dis_update (bool): Whether the loss is computed for the
                discriminator update (True) or the generator update (False).
        """
        if net_D_output['pred_fake'] is None:
            return self.Tensor(1).fill_(0) if dis_update else [
                self.Tensor(1).fill_(0), self.Tensor(1).fill_(0)]

        if dis_update:
            # Get the GAN loss for real/fake outputs.
            GAN_loss = \
                self.criteria['GAN'](net_D_output['pred_fake']['output'],
                                     False, dis_update=True) + \
                self.criteria['GAN'](net_D_output['pred_real']['output'],
                                     True, dis_update=True)
            return GAN_loss
        else:
            # Get the GAN loss and feature matching loss for fake output.
            GAN_loss = self.criteria['GAN'](
                net_D_output['pred_fake']['output'], True, dis_update=False)
            FM_loss = self.criteria['FeatureMatching'](
                net_D_output['pred_fake']['features'],
                net_D_output['pred_real']['features'])
            return GAN_loss, FM_loss

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Get data at the current time frame given the sequence of data.

        Args:
            data (dict): Training data for the current iteration.
            net_G_output (dict): Output of the generator (for previous frame).
            data_prev (dict): Data for the previous frame.
            t (int): Current time.
        """
        label = data['label'][:, t]
        image = data['images'][:, t]
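
        # prev_labels / prev_images keep only the most recent
        # (num_frames_G - 1) frames of labels and generated images, so the
        # generator always conditions on a fixed-length history.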
        if data_prev is not None:
            # Concat previous labels/fake images to the ones before.
            num_frames_G = self.cfg.data.num_frames_G
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], num_frames_G - 1)
            prev_images = concat_frames(
                data_prev['prev_images'],
                net_G_output['fake_images'].detach(), num_frames_G - 1)
        else:
            prev_labels = prev_images = None

        data_t = dict()
        data_t['label'] = label
        data_t['image'] = image
        data_t['prev_labels'] = prev_labels
        data_t['prev_images'] = prev_images
        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
        return data_t

    def _end_of_iteration(self, data, current_epoch, current_iteration):
        r"""Print the errors to console."""
        if not torch.distributed.is_initialized():
            if current_iteration % self.cfg.logging_iter == 0:
                message = '(epoch: %d, iters: %d) ' % (current_epoch,
                                                       current_iteration)
                for k, v in self.gen_losses.items():
                    if k != 'total':
                        message += '%s: %.3f, ' % (k, v)
                message += '\n'
                for k, v in self.dis_losses.items():
                    if k != 'total':
                        message += '%s: %.3f, ' % (k, v)
                print(message)

    def write_metrics(self):
        r"""If a moving average model is present, we have two meters: one
        for the regular FID and one for the average FID. If there is no
        moving average model, we only report the regular FID.
        """
        if self.cfg.trainer.model_average_config.enabled:
            regular_fid, average_fid = self._compute_fid()
            if regular_fid is None or average_fid is None:
                return
            metric_dict = {'FID/average': average_fid,
                           'FID/regular': regular_fid}
            self._write_to_meters(metric_dict, self.metric_meters,
                                  reduce=False)
        else:
            regular_fid = self._compute_fid()
            if regular_fid is None:
                return
            metric_dict = {'FID/regular': regular_fid}
            self._write_to_meters(metric_dict, self.metric_meters,
                                  reduce=False)
        self._flush_meters(self.metric_meters)

    def _compute_fid(self):
        r"""Compute FID values."""
        self.net_G.eval()
        self.net_G_output = None
        # Due to the complicated video evaluation procedure we are using,
        # we pass the trainer to the evaluation code instead of the
        # generator network.
        # net_G_for_evaluation = self.net_G
        trainer = self
        self.test_in_model_average_mode = False
        regular_fid_path = self._get_save_path('regular_fid', 'npy')
        few_shot = True if 'few_shot' in self.cfg.data.type else False
        regular_fid_value = compute_fid(regular_fid_path, self.val_data_loader,
                                        trainer,
                                        sample_size=self.sample_size,
                                        is_video=True, few_shot_video=few_shot)
        print('Epoch {:05}, Iteration {:09}, Regular FID {}'.format(
            self.current_epoch, self.current_iteration, regular_fid_value))
        if self.cfg.trainer.model_average_config.enabled:
            # Same as above: pass the trainer (rather than the averaged
            # generator) to the evaluation code.
            # avg_net_G_for_evaluation = self.net_G.module.averaged_model
            trainer_avg_mode = self
            self.test_in_model_average_mode = True
            # The above flag will be reset after computing FID.
            fid_path = self._get_save_path('average_fid', 'npy')
            few_shot = True if 'few_shot' in self.cfg.data.type else False
            fid_value = compute_fid(fid_path, self.val_data_loader,
                                    trainer_avg_mode,
                                    sample_size=self.sample_size,
                                    is_video=True, few_shot_video=few_shot)
            print('Epoch {:05}, Iteration {:09}, Average FID {}'.format(
                self.current_epoch, self.current_iteration, fid_value))
            self.net_G.float()
            return regular_fid_value, fid_value
        else:
            self.net_G.float()
            return regular_fid_value

    def visualize_label(self, label):
        r"""Visualize the input label when saving to image.

        Args:
            label (tensor): Input label tensor.
        """
        cfgdata = self.cfg.data
        if hasattr(cfgdata, 'for_pose_dataset'):
            label = tensor2pose(self.cfg, label)
        elif hasattr(cfgdata, 'input_labels') and \
                'seg_maps' in cfgdata.input_labels:
            for input_type in cfgdata.input_types:
                if 'seg_maps' in input_type:
                    num_labels = cfgdata.one_hot_num_classes.seg_maps
            label = tensor2label(label, num_labels)
        elif getattr(cfgdata, 'label_channels', 1) > 3:
            label = tensor2im(label.sum(1, keepdim=True))
        else:
            label = tensor2im(label)
        return label

    def save_image(self, path, data):
        r"""Save the output images to path.
        Note that when generate_raw_output is False,
        first_net_G_output['fake_raw_images'] is None and will not be
        displayed. In model average mode, we will plot the flow visualization
        twice.

        Args:
            path (str): Save path.
            data (dict): Training data for the current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.averaged_model.eval()

        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
            if self.cfg.trainer.model_average_config.enabled:
                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
                    data, use_model_average=True)

        # Visualize labels.
        label_lengths = self.train_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels_start, vis_labels_end = [], []
        for key, value in labels.items():
            if key == 'seg_maps':
                vis_labels_start.append(self.visualize_label(value[:, -1]))
                vis_labels_end.append(self.visualize_label(value[:, 0]))
            else:
                vis_labels_start.append(tensor2im(value[:, -1]))
                vis_labels_end.append(tensor2im(value[:, 0]))

        if is_master():
            vis_images = [
                *vis_labels_start,
                tensor2im(data['images'][:, -1]),
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])]
            if self.cfg.trainer.model_average_config.enabled:
                vis_images += [
                    tensor2im(net_G_output_avg['fake_images']),
                    tensor2im(net_G_output_avg['fake_raw_images'])]

            if self.sequence_length > 1:
                vis_images_first = [
                    *vis_labels_end,
                    tensor2im(data['images'][:, 0]),
                    tensor2im(first_net_G_output['fake_images']),
                    tensor2im(first_net_G_output['fake_raw_images'])
                ]
                if self.cfg.trainer.model_average_config.enabled:
                    vis_images_first += [
                        tensor2im(first_net_G_output_avg['fake_images']),
                        tensor2im(first_net_G_output_avg['fake_raw_images'])]

                if self.use_flow:
                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                        data['images'][:, -1], data['images'][:, -2])
                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt),
                    ]
                    vis_images += [
                        tensor2flow(net_G_output['fake_flow_maps']),
                        tensor2im(net_G_output['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output['warped_images']),
                    ]
                    if self.cfg.trainer.model_average_config.enabled:
                        vis_images_first += [
                            tensor2flow(flow_gt),
                            tensor2im(conf_gt, normalize=False),
                            tensor2im(warped_image_gt),
                        ]
                        vis_images += [
                            tensor2flow(net_G_output_avg['fake_flow_maps']),
                            tensor2im(
                                net_G_output_avg['fake_occlusion_masks'],
                                normalize=False),
                            tensor2im(net_G_output_avg['warped_images'])]

                vis_images = [[np.vstack((im_first, im))
                               for im_first, im in zip(imgs_first, imgs)]
                              for imgs_first, imgs in zip(vis_images_first,
                                                          vis_images)
                              if imgs is not None]

            image_grid = np.hstack([np.vstack(im) for im in
                                    vis_images if im is not None])
            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

            # Gather all outputs for dumping into video.
            if self.sequence_length > 1:
                output_images = []
                for item in all_info['outputs']:
                    output_images.append(tensor2im(item['fake_images'])[0])
                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
                                 output_images, fps=2, macro_block_size=None)

        self.net_G.float()