Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
import math | |
from datetime import timedelta | |
import torch | |
import wandb | |
from wandb import AlertLevel | |
from torch.utils.tensorboard import SummaryWriter | |
from imaginaire.utils.distributed import master_only, dist_all_reduce_tensor, \ | |
is_master, get_rank | |
from imaginaire.utils.distributed import master_only_print as print | |
LOG_WRITER = None | |
LOG_DIR = None | |
def sn_reshape_weight_to_matrix(weight):
    r"""Flatten a layer weight into a 2-D matrix.

    The size of the first dimension is kept as the number of rows; all
    remaining dimensions are collapsed into the columns.

    Args:
        weight (Parameters): pytorch layer parameter tensor.

    Returns:
        The weight reshaped to ``(weight.size(0), -1)``.
    """
    num_rows = weight.size(0)
    return weight.reshape(num_rows, -1)
def get_weight_stats(mod):
    r"""Collect gradient norm, weight norm and sigma for a module.

    Args:
        mod: Pytorch module. The code reads its ``weight_orig``,
            ``weight_u`` and ``weight_v`` attributes.

    Returns:
        tuple: ``(grad_norm, weight_norm, sigma)``. The two norms are
        Python floats (``grad_norm`` is ``0.`` when ``weight_orig`` has no
        gradient); ``sigma`` is the tensor produced by summing
        ``weight_u * (W @ weight_v)`` with ``W`` the 2-D reshaped weight.
    """
    weight = mod.weight_orig
    grad = weight.grad
    grad_norm = 0. if grad is None else grad.data.norm().item()
    weight_norm = weight.data.norm().item()
    # Project weight_v through the flattened weight, then contract with
    # weight_u to obtain sigma.
    projected = torch.mv(sn_reshape_weight_to_matrix(weight), mod.weight_v)
    sigma = (mod.weight_u * projected).sum()
    return grad_norm, weight_norm, sigma
def set_summary_writer(log_dir):
    r"""Create the module-wide summary writer and remember its directory.

    Both module globals are updated: ``LOG_DIR`` stores the directory and
    ``LOG_WRITER`` stores a new ``SummaryWriter`` created for it.

    Args:
        log_dir (str): Log directory.
    """
    global LOG_DIR
    global LOG_WRITER
    # Record the directory first so it is set even if the writer
    # construction below raises.
    LOG_DIR = log_dir
    LOG_WRITER = SummaryWriter(log_dir=LOG_DIR)
def write_summary(name, summary, step, hist=False):
    """Send one summary entry to the module-level log writer.

    Args:
        name: Tag forwarded to the writer.
        summary: Value forwarded to the writer.
        step: Step index forwarded to the writer.
        hist (bool): If ``True`` the entry is written with
            ``add_histogram``; otherwise with ``add_scalar``.

    Raises:
        Exception: If no log writer has been set (``LOG_WRITER`` is None).
    """
    writer = LOG_WRITER
    if writer is None:
        raise Exception("Log writer not set.")
    emit = writer.add_histogram if hist else writer.add_scalar
    emit(name, summary, step)
class Meter(object):
    """Meter is to keep track of statistics along steps.

    Meters write values for purpose like printing average values.
    Meters can be flushed to log files (TensorBoard and wandb) regularly.

    Args:
        name (str): the name of meter
        reduce (bool): If ``True``, perform a distributed reduce for the log
            values across all GPUs. If ``False``, only rank 0 records and
            flushes values; every other rank is a no-op.
    """

    def __init__(self, name, reduce=True):
        self.name = name
        self.reduce = reduce
        self.values = []

    def _skip_on_this_rank(self):
        r"""Return ``True`` when this meter must do nothing on this rank.

        A non-reducing meter is only active on rank 0. The ``and``
        short-circuits, so ``get_rank()`` is only queried for non-reducing
        meters.
        """
        return not self.reduce and get_rank() != 0

    def reset(self):
        r"""Reset the meter values"""
        if self._skip_on_this_rank():
            return
        self.values = []

    def write(self, value):
        r"""Record the value. ``None`` values are ignored.

        Args:
            value: Value to append to the meter.
        """
        if self._skip_on_this_rank():
            return
        if value is not None:
            self.values.append(value)

    def flush(self, step):
        r"""Average the recorded finite values, log the result, and reset.

        When ``reduce`` is set, the recorded values are first passed through
        ``dist_all_reduce_tensor`` and the logged average is computed from
        the *reduced* values, so the scalar reflects all ranks rather than
        only the local one. Non-finite entries trigger a printed message and
        a wandb alert (master only) and are excluded from the average.
        Nothing is logged when no finite value remains. Logging to
        TensorBoard and wandb happens on the master process only.

        Args:
            step (int): Epoch or iteration number.
        """
        if self._skip_on_this_rank():
            return
        # Fall back to CPU so flushing does not crash on hosts without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # float64 keeps the precision of the Python floats that were
        # recorded, so the single-process average is unchanged.
        values = torch.tensor(self.values, dtype=torch.float64, device=device)
        if self.reduce:
            values = dist_all_reduce_tensor(values)

        finite_mask = torch.isfinite(values)
        if not bool(finite_mask.all()):
            print("meter {} contained a nan or inf.".format(self.name))
            if is_master():
                wandb.alert(
                    title='NaN',
                    text=f'Meter {self.name} contained a nan or inf.',
                    level=AlertLevel.WARN,
                    wait_duration=timedelta(minutes=120)
                )
        # Average the reduced values (not the local ``self.values``) so the
        # distributed reduce actually affects what is logged.
        finite_values = values[finite_mask].tolist()
        if finite_values:
            value = sum(finite_values) / len(finite_values)
            if is_master():
                write_summary(self.name, value, step)
                wandb.log({self.name: value}, step=step)
        self.reset()

    def write_image(self, img_grid, step):
        r"""Write an image to the tensorboard under the "Visualizations" tag.

        Args:
            img_grid: Image data forwarded to the writer's ``add_image``.
            step (int): Epoch or iteration number.

        Raises:
            Exception: If no log writer has been set.
        """
        if self._skip_on_this_rank():
            return
        lw = LOG_WRITER
        if lw is None:
            raise Exception("Log writer not set.")
        lw.add_image("Visualizations", img_grid, step)