from functools import partial
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

EPS = 1e-8


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""

    def __init__(self, channel_size):
        super(_LayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Assumes input of size `[batch, channel, *]`."""
        # Move the channel axis last so the [channel_size] gain and bias
        # broadcast per channel, then restore the original layout.
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)


class GlobLN(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies forward pass.



        Works for any input size > 2D.



        Args:

            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`



        Returns:

            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`

        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class ChanLN(_LayerNorm):
    """Channel-wise Layer Normalization (chanLN)."""

    def forward(self, x):
        """Applies forward pass.



        Works for any input size > 2D.



        Args:

            x (:class:`torch.Tensor`): `[batch, chan, *]`



        Returns:

            :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`

        """
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class CumLN(_LayerNorm):
    """Cumulative Global layer normalization(cumLN)."""

    def forward(self, x):
        """



        Args:

            x (:class:`torch.Tensor`): Shape `[batch, channels, length]`

        Returns:

             :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`

        """
        batch, chan, spec_len = x.size()
        # Cumulative sums over time of the per-frame channel sums, so that
        # frame t aggregates statistics over x[:, :, : t + 1].
        cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        # Number of elements aggregated at each frame: chan * (t + 1).
        cnt = torch.arange(
            start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype, device=x.device
        ).view(1, 1, -1)
        cum_mean = cum_sum / cnt
        # Biased variance from the running moments: E[x^2] - E[x]^2.
        cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
        return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())
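
# Note: for each frame ``t``, the running statistics above match global
# statistics recomputed over the prefix ``x[:, :, : t + 1]`` (biased variance),
# e.g. for a single ``t``:
#
#     ref_mean = x[:, :, : t + 1].mean(dim=(1, 2))
#     ref_var = x[:, :, : t + 1].var(dim=(1, 2), unbiased=False)
#
# but are obtained in a single pass over the sequence with cumulative sums.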


class FeatsGlobLN(_LayerNorm):
    """feature-wise global Layer Normalization (FeatsGlobLN).

    Applies normalization over frames for each channel."""

    def forward(self, x):
        """Applies forward pass.



        Works for any input size > 2D.



        Args:

            x (:class:`torch.Tensor`): `[batch, chan, time]`



        Returns:

            :class:`torch.Tensor`: chanLN_x `[batch, chan, time]`

        """

        stop = len(x.size())
        dims = list(range(2, stop))

        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D"""

    def _check_input_dim(self, input):
        if input.dim() < 2 or input.dim() > 4:
            raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim()))


# Aliases.
gLN = GlobLN
fgLN = FeatsGlobLN
cLN = ChanLN
cgLN = CumLN
bN = BatchNorm


def register_norm(custom_norm):
    """Register a custom norm, gettable with `norms.get`.



    Args:

        custom_norm: Custom norm to register.



    """
    if custom_norm.__name__ in globals().keys() or custom_norm.__name__.lower() in globals().keys():
        raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.")
    globals().update({custom_norm.__name__: custom_norm})


def get(identifier):
    """Returns a norm class from a string. Returns its input if it

    is callable (already a :class:`._LayerNorm` for example).



    Args:

        identifier (str or Callable or None): the norm identifier.



    Returns:

        :class:`._LayerNorm` or None

    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret normalization identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret normalization identifier: " + str(identifier))