File size: 4,350 Bytes
f670afc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import math
from datetime import timedelta

import torch
import wandb
from wandb import AlertLevel
from torch.utils.tensorboard import SummaryWriter

from imaginaire.utils.distributed import master_only, dist_all_reduce_tensor, \
    is_master, get_rank

from imaginaire.utils.distributed import master_only_print as print

# Module-wide TensorBoard writer and the directory it writes to. Both remain
# ``None`` until ``set_summary_writer`` assigns them.
LOG_WRITER = LOG_DIR = None


@torch.no_grad()
def sn_reshape_weight_to_matrix(weight):
    r"""Reshape weight to obtain the matrix form.

    Args:
        weight (Parameters): pytorch layer parameter tensor.
    """
    weight_mat = weight
    height = weight_mat.size(0)
    return weight_mat.reshape(height, -1)


@torch.no_grad()
def get_weight_stats(mod):
    r"""Get norm statistics of a spectrally normalized layer.

    Args:
         mod: Pytorch module carrying ``weight_orig``, ``weight_u`` and
             ``weight_v`` attributes (presumably installed by
             ``torch.nn.utils.spectral_norm`` -- confirm against callers).

    Returns:
        tuple: ``(grad_norm, weight_norm, sigma)`` where ``grad_norm`` is the
        norm of ``weight_orig.grad`` as a float (``0.`` when no gradient
        exists), ``weight_norm`` is the norm of ``weight_orig`` as a float,
        and ``sigma`` is the tensor ``sum(u * (W @ v))`` with ``W`` the
        matrix form of ``weight_orig``.
    """
    orig = mod.weight_orig
    grad = orig.grad
    grad_norm = 0. if grad is None else grad.data.norm().item()
    weight_norm = orig.data.norm().item()
    # u^T (W v): presumably the power-iteration estimate of W's largest
    # singular value.
    w_mat = sn_reshape_weight_to_matrix(orig)
    sigma = (mod.weight_u * torch.mv(w_mat, mod.weight_v)).sum()
    return grad_norm, weight_norm, sigma


@master_only
def set_summary_writer(log_dir):
    r"""Create the module-wide TensorBoard ``SummaryWriter``.

    Records ``log_dir`` in ``LOG_DIR`` and stores a writer targeting it in
    ``LOG_WRITER``. The ``master_only`` decorator presumably restricts this
    to the master process.

    Args:
        log_dir (str): Log directory.
    """
    global LOG_WRITER, LOG_DIR
    LOG_DIR = log_dir
    LOG_WRITER = SummaryWriter(log_dir=LOG_DIR)


def write_summary(name, summary, step, hist=False):
    """Write one scalar or histogram entry through the global ``LOG_WRITER``.

    Args:
        name (str): Tag to log the value under.
        summary: Value to log; given to ``add_histogram`` when ``hist`` is
            true, otherwise to ``add_scalar``.
        step (int): Step associated with the value.
        hist (bool): Log as a histogram instead of a scalar.

    Raises:
        Exception: If the module-level ``LOG_WRITER`` is still ``None``.
    """
    # Read-only access to the module global; no ``global`` statement needed.
    writer = LOG_WRITER
    if writer is None:
        raise Exception("Log writer not set.")
    log_fn = writer.add_histogram if hist else writer.add_scalar
    log_fn(name, summary, step)


class Meter(object):
    """Meter is to keep track of statistics along steps.

    Values are buffered with :meth:`write`. On :meth:`flush` they are
    (optionally) averaged across ranks, non-finite entries are dropped, and
    the mean is written to TensorBoard and wandb from the master process.

    Args:
        name (str): the name of meter
        reduce (bool): If ``True``, perform a distributed reduce for the log
            values across all GPUs. If ``False``, only rank 0 records and logs
            values; calls on other ranks are ignored.
    """

    def __init__(self, name, reduce=True):
        self.name = name
        self.reduce = reduce
        self.values = []

    def _is_inactive(self):
        """Return True when this rank should ignore calls on this meter."""
        # ``self.reduce`` is tested first so get_rank() is only queried for
        # non-reduced meters, which are kept on rank 0 only.
        return not self.reduce and get_rank() != 0

    def reset(self):
        r"""Reset the meter values"""
        if self._is_inactive():
            return
        self.values = []

    def write(self, value):
        r"""Record the value; ``None`` values are ignored."""
        if self._is_inactive():
            return
        if value is not None:
            self.values.append(value)

    def flush(self, step):
        r"""Log the mean of the buffered values, then clear the buffer.

        With ``reduce=True`` this performs a collective all-reduce, so every
        rank must call it.

        Args:
            step (int): Epoch or iteration number.
        """
        if self._is_inactive():
            return
        if self.reduce:
            # Fall back to CPU so the meter does not hard-require CUDA.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            reduced = dist_all_reduce_tensor(
                torch.tensor(self.values, device=device))
            # A single device->host copy instead of one sync per element.
            values = reduced.tolist()
        else:
            values = list(self.values)

        if not all(math.isfinite(x) for x in values):
            print("meter {} contained a nan or inf.".format(self.name))
            if is_master():
                wandb.alert(
                    title='NaN',
                    text=f'Meter {self.name} contained a nan or inf.',
                    level=AlertLevel.WARN,
                    wait_duration=timedelta(minutes=120)
                )
        # Average the (reduced) values, not the local ``self.values``, so the
        # logged number actually reflects all ranks when ``reduce`` is set.
        finite_values = [x for x in values if math.isfinite(x)]
        if finite_values:
            value = float(sum(finite_values)) / len(finite_values)
            if is_master():
                write_summary(self.name, value, step)
                wandb.log({self.name: value}, step=step)
        self.reset()

    @master_only
    def write_image(self, img_grid, step):
        r"""Write an image grid to TensorBoard under the "Visualizations" tag.

        Args:
            img_grid: Image forwarded to ``SummaryWriter.add_image``
                (presumably a CHW tensor -- confirm against callers).
            step (int): Epoch or iteration number.

        Raises:
            Exception: If the module-level ``LOG_WRITER`` is ``None``.
        """
        if self._is_inactive():
            return
        lw = LOG_WRITER
        if lw is None:
            raise Exception("Log writer not set.")
        lw.add_image("Visualizations", img_grid, step)