import torch
import torch.distributed as dist
from math import inf  # torch._six was removed in recent PyTorch; its inf was just math.inf
import io
from timm.utils import get_state_dict
try:
    from apex import amp
    APEX_INSTALLED = True
except ImportError:
    print('apex has not been installed.')
    APEX_INSTALLED = False


class NativeScalerWithGradNormCount:
    """Thin wrapper around torch.cuda.amp.GradScaler that also returns the gradient norm
    and, for fp16 runs, adapts the scaler's growth interval when the loss scale collapses."""

    state_dict_key = "amp_scaler"

    def __init__(self, enabled=True, growth_interval=500, init_scale=2.**13):
        self.enabled = enabled
        self._scaler = torch.cuda.amp.GradScaler(
            init_scale=init_scale, growth_interval=growth_interval, enabled=self.enabled)

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True,
                 fp16=False, iter=0, min_loss_scale=2048.0, loss_scale_window=200):
        self._scaler.scale(loss).backward(create_graph=create_graph)

        if fp16:
            # Stabilize pure-fp16 training: after warmup, if the loss scale has fallen
            # below min_loss_scale, shrink the growth interval so the scale can recover
            # quickly; once it has recovered, restore the normal loss_scale_window.
            if iter > 5000 and self._scaler.get_scale() < min_loss_scale:
                min_growth_interval = 5
                if self._scaler.get_growth_interval() != min_growth_interval:
                    self._scaler.set_growth_interval(min_growth_interval)
            elif iter > 5000 and self._scaler.get_growth_interval() == 5:
                self._scaler.set_growth_interval(loss_scale_window)

        if update_grad:
            if clip_grad is not None and clip_grad > 0.0:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # no clipping requested: still unscale so the reported norm is in
                # true (unscaled) gradient units
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
        else:
            norm = None
        return norm

    def step(self, optimizer):
        self._scaler.step(optimizer)

    def update(self):
        self._scaler.update()

    def get_scale(self):
        return self._scaler.get_scale()

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
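

# Illustrative usage sketch (not in the original file): one training step driving
# NativeScalerWithGradNormCount. The tiny model, optimizer, and random batch below are
# hypothetical placeholders; only the scaler call pattern mirrors the class above.
def _example_native_scaler_step():
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # enabled=False keeps the sketch runnable on CPU; on GPU the forward pass would
    # normally be wrapped in torch.autocast("cuda") and the scaler enabled.
    scaler = NativeScalerWithGradNormCount(enabled=False)

    inputs = torch.randn(4, 8)
    targets = torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)

    optimizer.zero_grad()
    # backward pass + optional gradient clipping; returns the gradient norm
    grad_norm = scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
    scaler.step(optimizer)  # optimizer step through the underlying GradScaler
    scaler.update()         # adjust the loss scale for the next iteration
    return grad_norm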


class ApexScalerWithGradNormCount:
    """Counterpart of NativeScalerWithGradNormCount for apex.amp, which handles loss
    scaling itself; step() defers to the optimizer and update() is a no-op."""

    state_dict_key = "amp_scaler"

    def __init__(self, enabled=True):
        assert APEX_INSTALLED, 'apex must be installed to use ApexScalerWithGradNormCount.'
        self.enabled = enabled
        self._scaler = amp

    def __call__(self,
                 loss,
                 optimizer,
                 clip_grad=None,
                 parameters=None,
                 create_graph=False,
                 update_grad=True,
                 fp16=False,
                 iter=0,
                 min_loss_scale=2048.0,
                 loss_scale_window=200):

        with self._scaler.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        if update_grad:
            if clip_grad is not None and clip_grad > 0.0:
                norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), clip_grad)
            else:
                norm = get_grad_norm_(amp.master_params(optimizer))
        else:
            norm = None
        return norm

    def step(self, optimizer):
        optimizer.step()

    def update(self):
        pass

    def get_scale(self):
        return self._scaler.state_dict()['loss_scaler0']['loss_scale']

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
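

# Illustrative sketch (not in the original file): the same training-step pattern with
# ApexScalerWithGradNormCount. It assumes apex is installed and that `model` and
# `optimizer` were already prepared with apex.amp.initialize(); all arguments here are
# hypothetical placeholders supplied by the caller.
def _example_apex_scaler_step(model, optimizer, inputs, targets):
    scaler = ApexScalerWithGradNormCount()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)

    optimizer.zero_grad()
    # apex scales the loss inside scale_loss(); clipping operates on amp.master_params
    grad_norm = scaler(loss, optimizer, clip_grad=1.0)
    scaler.step(optimizer)  # plain optimizer.step(); apex already unscaled the grads
    scaler.update()         # no-op for apex
    return grad_norm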


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    # Total gradient norm over every parameter that currently has a .grad tensor.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm
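

# Illustrative sketch (not in the original file): computing gradient norms for a
# throwaway module with get_grad_norm_; the module and data are hypothetical.
def _example_grad_norm():
    model = torch.nn.Linear(4, 1)
    model(torch.randn(2, 4)).sum().backward()
    l2_norm = get_grad_norm_(model.parameters())                  # global L2 norm
    max_norm = get_grad_norm_(model.parameters(), norm_type=inf)  # largest |grad| entry
    return l2_norm, max_norm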