import numpy as np
import torch


class AdaptiveGradClipper:
    """Adaptive gradient clipping for training.

    Tracks recently observed gradient norms in a ring buffer and, once the
    buffer is full, clips to the ``clip_percentile`` percentile of those
    norms, optionally capped by ``max_norm``.
    """

    def __init__(
        self,
        max_norm=None,
        clip_percentile=95.0,
        buffer_size=1000,
    ):
        self.max_norm = max_norm
        self.clip_percentile = clip_percentile
        self.buffer_size = buffer_size
        
        # Ring buffer of recently observed finite gradient norms.
        self._grad_norm = np.zeros(buffer_size, dtype=np.float32)
        # Effective clipping threshold; adapted once the buffer is full.
        self._max_norm = max_norm
        self._buffer_ptr = 0
        self._buffer_length = 0

    def __repr__(self):
        return f'AdaptiveGradClipper(max_norm={self.max_norm}, clip_percentile={self.clip_percentile})'
        
    def state_dict(self):
        return {
            'grad_norm': self._grad_norm,
            'max_norm': self._max_norm,
            'buffer_ptr': self._buffer_ptr,
            'buffer_length': self._buffer_length,
        }

    def load_state_dict(self, state_dict):
        self._grad_norm = state_dict['grad_norm']
        self._max_norm = state_dict['max_norm']
        self._buffer_ptr = state_dict['buffer_ptr']
        self._buffer_length = state_dict['buffer_length']

    def log(self):
        return {
            'max_norm': self._max_norm,
        }

    def __call__(self, parameters, norm_type=2.0, error_if_nonfinite=False, foreach=None):
        """Clip the gradient norm of an iterable of parameters.

        The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.

        Args:
            parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
                single Tensor that will have gradients normalized.
            norm_type (float): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.
            error_if_nonfinite (bool): if True, an error is thrown if the total
                norm of the gradients from :attr:`parameters` is ``nan``,
                ``inf``, or ``-inf``. Default: False (will switch to True in the future)
            foreach (bool): use the faster foreach-based implementation.
                If ``None``, use the foreach implementation for CUDA and CPU native
                tensors and silently fall back to the slow implementation for other
                device types. Default: ``None``

        Returns:
            Total norm of the parameter gradients (viewed as a single vector).
        """
        # Clip with the current threshold; until one is available, use inf (no clipping).
        max_norm = self._max_norm if self._max_norm is not None else float('inf')
        grad_norm = torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
            foreach=foreach,
        )

        # Record only finite norms. Once the buffer is full, set the threshold to the
        # configured percentile of recent norms, capped by the user-provided max_norm.
        if torch.isfinite(grad_norm):
            self._grad_norm[self._buffer_ptr] = grad_norm.item()
            self._buffer_ptr = (self._buffer_ptr + 1) % self.buffer_size
            self._buffer_length = min(self._buffer_length + 1, self.buffer_size)
            if self._buffer_length == self.buffer_size:
                self._max_norm = float(np.percentile(self._grad_norm, self.clip_percentile))
                if self.max_norm is not None:
                    self._max_norm = min(self._max_norm, self.max_norm)

        return grad_norm
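

# Minimal usage sketch (illustrative only, not part of the module's API): the model,
# data, and optimizer below are placeholders showing where the clipper sits in a
# training step, i.e. between loss.backward() and optimizer.step().
if __name__ == '__main__':
    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    clipper = AdaptiveGradClipper(max_norm=10.0, clip_percentile=95.0, buffer_size=100)

    for _ in range(200):
        x = torch.randn(8, 16)
        y = torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)

        optimizer.zero_grad()
        loss.backward()
        # Clips gradients in-place and returns the pre-clip total norm.
        total_norm = clipper(model.parameters())
        optimizer.step()

    print(clipper.log())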