# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import copy | |
import torch | |
from torch import nn | |
from imaginaire.layers.weight_norm import remove_weight_norms | |
from imaginaire.utils.misc import requires_grad | |
def reset_batch_norm(m):
    r"""Reset the running statistics of a batch-norm-like module.

    Args:
        m: PyTorch module. Only modules exposing a
            ``reset_running_stats`` method are affected; anything else
            is left untouched.
    """
    reset_fn = getattr(m, 'reset_running_stats', None)
    if reset_fn is not None:
        reset_fn()
def calibrate_batch_norm_momentum(m):
    r"""Set batch norm momentum so the running stats form a plain average.

    With ``momentum = 1 / (t + 1)`` at step ``t``, the exponential update
    of the running statistics reduces to an exact cumulative mean over all
    batches seen so far.

    Args:
        m: PyTorch module. Only modules that track running stats and whose
            class name contains ``BatchNorm`` are modified.
    """
    if not hasattr(m, 'reset_running_stats'):
        return
    if 'BatchNorm' not in m._get_name():
        return
    m.momentum = 1.0 / float(m.num_batches_tracked + 1)
class ModelAverage(nn.Module):
    r"""Maintain an exponential-moving-average (EMA) copy of a network.

    In this model average implementation, the weight-normalization wrapper
    layers are absorbed into the averaged model's parameters by default.
    If such options are turned on, be careful with how you do the
    training. Remember to re-estimate the batch norm statistics before
    using the averaged model.

    Args:
        module (torch nn module): Torch network to average.
        beta (float): Moving average weight. How much we weight the past.
        start_iteration (int): From which iteration we start the update.
            Before that, the average is simply overwritten with the
            current weights (beta is forced to 0 in ``update_average``).
        remove_wn_wrapper (bool): Whether to strip the weight-norm
            wrappers (scaled-lr / spectral norm) from the averaged copy
            and fold them into plain weight tensors.
    """

    def __init__(
            self, module, beta=0.9999, start_iteration=1000,
            remove_wn_wrapper=True
    ):
        super(ModelAverage, self).__init__()
        self.module = module
        # A shallow copy creates a new object which stores the reference of
        # the original elements.
        # A deep copy creates a new object and recursively adds the copies of
        # nested objects present in the original elements.
        # NOTE(review): device is hard-coded to 'cuda'; this class assumes a
        # GPU is available — confirm before running on CPU-only hosts.
        self.averaged_model = copy.deepcopy(self.module).to('cuda')
        self.beta = beta
        self.remove_wn_wrapper = remove_wn_wrapper
        self.start_iteration = start_iteration
        # This buffer tracks how many iterations the model has been
        # trained for. We ignore the first $(start_iteration) updates and
        # start the averaging after.
        self.register_buffer('num_updates_tracked',
                             torch.tensor(0, dtype=torch.long))
        self.num_updates_tracked = self.num_updates_tracked.to('cuda')
        # Historical spectral-norm-only removal path, kept for reference.
        # if self.remove_sn:
        #     # If we want to remove the spectral norm, we first copy the
        #     # weights to the moving average model.
        #     self.copy_s2t()
        #
        #     def fn_remove_sn(m):
        #         r"""Remove spectral norm."""
        #         if hasattr(m, 'weight_orig'):
        #             remove_spectral_norm(m)
        #
        #     self.averaged_model.apply(fn_remove_sn)
        #     self.dim = 0
        if self.remove_wn_wrapper:
            # Seed the averaged copy with the current weights, then strip
            # the wrapper modules so it holds plain weight tensors.
            self.copy_s2t()
            self.averaged_model.apply(remove_weight_norms)
            # Dim along which weights are flattened when re-computing the
            # spectral-norm sigma in sn_compute_weight.
            self.dim = 0
        else:
            self.averaged_model.eval()
        # Averaged model does not require grad.
        requires_grad(self.averaged_model, False)

    def forward(self, *inputs, **kwargs):
        r"""PyTorch module forward function overload.

        Runs the *regular* (trained) module, not the averaged copy.
        """
        return self.module(*inputs, **kwargs)

    def update_average(self):
        r"""Update the moving average from the regular module's weights."""
        self.num_updates_tracked += 1
        # During warm-up, beta=0 turns the EMA update into a plain copy.
        if self.num_updates_tracked <= self.start_iteration:
            beta = 0.
        else:
            beta = self.beta
        source_dict = self.module.state_dict()
        target_dict = self.averaged_model.state_dict()
        for key in target_dict:
            # Integer bookkeeping buffers are not averaged.
            if 'num_batches_tracked' in key:
                continue
            if self.remove_wn_wrapper:
                # The averaged model holds plain weights, so the wrapped
                # parameters of the regular model must be folded first.
                # Key suffixes ('_ori'/'_scale' for scaled lr, '_orig'/
                # '_u'/'_v' for spectral norm) follow the conventions of
                # imaginaire's weight-norm wrappers — presumably defined
                # in imaginaire.layers.weight_norm; verify there.
                if key.endswith('weight'):
                    # This is a weight parameter.
                    if key + '_ori' in source_dict:
                        # This parameter has scaled lr.
                        source_param = \
                            source_dict[key + '_ori'] * \
                            source_dict[key + '_scale']
                    elif key + '_orig' in source_dict:
                        # This parameter has spectral norm
                        # but not scaled lr.
                        source_param = source_dict[key + '_orig']
                    elif key in source_dict:
                        # This parameter does not have
                        # weight normalization wrappers.
                        source_param = source_dict[key]
                    else:
                        raise ValueError(
                            f"{key} required in the averaged model but not "
                            f"found in the regular model."
                        )
                    source_param = source_param.detach()
                    if key + '_orig' in source_dict:
                        # This parameter has spectral norm: divide by the
                        # estimated top singular value.
                        source_param = self.sn_compute_weight(
                            source_param,
                            source_dict[key + '_u'],
                            source_dict[key + '_v'],
                        )
                elif key.endswith('bias') and key + '_ori' in source_dict:
                    # This is a bias parameter and has scaled lr.
                    source_param = source_dict[key + '_ori'] * \
                        source_dict[key + '_scale']
                else:
                    # This is a normal parameter.
                    source_param = source_dict[key]
                # EMA update: target = beta * target + (1 - beta) * source.
                target_dict[key].data.mul_(beta).add_(
                    source_param.data, alpha=1 - beta
                )
            else:
                # Keys match one-to-one when wrappers are kept.
                target_dict[key].data.mul_(beta).add_(
                    source_dict[key].data, alpha=1 - beta
                )

    def copy_t2s(self):
        r"""Copy the moving-average weights back into the regular module.

        Despite the symmetric naming with ``copy_s2t``, here the *target*
        is the regular module and the *source* is the averaged model.
        """
        target_dict = self.module.state_dict()
        source_dict = self.averaged_model.state_dict()
        # beta = 0 reduces the blend below to a plain copy of source.
        beta = 0.
        for key in source_dict:
            target_dict[key].data.copy_(
                target_dict[key].data * beta +
                source_dict[key].data * (1 - beta))

    def copy_s2t(self):
        r"""Copy state_dictionary from source to target.

        Here source is the regular module and the target is the moving
        average module. Basically, we will copy weights in the regular
        module to the moving average module.
        """
        source_dict = self.module.state_dict()
        target_dict = self.averaged_model.state_dict()
        # beta = 0 reduces the blend below to a plain copy of source.
        beta = 0.
        for key in source_dict:
            target_dict[key].data.copy_(
                target_dict[key].data * beta +
                source_dict[key].data * (1 - beta))

    def __repr__(self):
        r"""Returns a string that holds a printable representation of an
        object. Delegates to the wrapped module."""
        return self.module.__repr__()

    def sn_reshape_weight_to_matrix(self, weight):
        r"""Reshape weight to obtain the matrix form used for the
        spectral-norm power iteration.

        Args:
            weight (Parameters): pytorch layer parameter tensor.
        Returns:
            (Parameters): 2D reshaped weight matrix, with ``self.dim``
            moved to the front. NOTE(review): ``self.dim`` is only set
            when ``remove_wn_wrapper`` is True — confirm this method is
            never reached otherwise.
        """
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def sn_compute_weight(self, weight, u, v):
        r"""Compute the spectral-norm-normalized weight.

        Args:
            weight (Parameters): pytorch layer parameter tensor.
            u (tensor): left singular vector estimate.
            v (tensor): right singular vector estimate.
        Returns:
            (Parameters): weight divided by the estimated top singular
            value ``sigma = u^T W v``.
        """
        weight_mat = self.sn_reshape_weight_to_matrix(weight)
        sigma = torch.sum(u * torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight