from functools import partial
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
EPS = 1e-8


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""

    def __init__(self, channel_size):
        super(_LayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Assumes input of size `[batch, channel, *]`."""
        # Move the channel axis last so gamma/beta (shape [channel_size]) broadcast
        # over it, then move it back.
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)


class GlobLN(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class ChanLN(_LayerNorm):
    """Channel-wise Layer Normalization (chanLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
        """
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class CumLN(_LayerNorm):
    """Cumulative Global layer normalization (cumLN)."""

    def forward(self, x):
        """

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, channels, length]`

        Returns:
            :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`
        """
        batch, chan, spec_len = x.size()
        # Running sums over channels and all frames up to (and including) each step.
        cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        # Number of elements contributing to the statistics at each step:
        # chan, 2 * chan, ..., spec_len * chan.
        cnt = torch.arange(
            start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype, device=x.device
        ).view(1, 1, -1)
        cum_mean = cum_sum / cnt
        # Var = E[x^2] - E[x]^2, so the cumulative power sum must also be divided by cnt.
        cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
        return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())


class FeatsGlobLN(_LayerNorm):
    """Feature-wise global Layer Normalization (FeatsGlobLN).

    Applies normalization over frames for each channel.
    """

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, time]`

        Returns:
            :class:`torch.Tensor`: fgLN_x `[batch, chan, time]`
        """
        stop = len(x.size())
        dims = list(range(2, stop))
        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D."""

    def _check_input_dim(self, input):
        if input.dim() < 2 or input.dim() > 4:
            raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(input.dim()))


# Aliases.
gLN = GlobLN
fgLN = FeatsGlobLN
cLN = ChanLN
cgLN = CumLN
bN = BatchNorm
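# Note: because ``get`` below resolves identifiers through ``globals()``, both the full
# class names ("GlobLN", "ChanLN", ...) and these short aliases ("gLN", "cLN", ...)
# can be used as string identifiers.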


def register_norm(custom_norm):
    """Register a custom norm, gettable with `norms.get`.

    Args:
        custom_norm: Custom norm to register.
    """
    if custom_norm.__name__ in globals().keys() or custom_norm.__name__.lower() in globals().keys():
        raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.")
    globals().update({custom_norm.__name__: custom_norm})
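

# Illustrative sketch (not part of the original module): a hypothetical custom norm
# ``MyNorm`` could be registered and later resolved by name with ``get``:
#
#   class MyNorm(_LayerNorm):
#       def forward(self, x):
#           return self.apply_gain_and_bias(x)
#
#   register_norm(MyNorm)
#   assert get("MyNorm") is MyNorm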


def get(identifier):
    """Returns a norm class from a string. Returns its input if it
    is callable (already a :class:`._LayerNorm` for example).

    Args:
        identifier (str or Callable or None): the norm identifier.

    Returns:
        :class:`._LayerNorm` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret normalization identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret normalization identifier: " + str(identifier))
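

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): resolve a
    # norm class from its string identifier and apply it to a [batch, chan, time] input.
    norm_cls = get("gLN")  # -> GlobLN
    norm = norm_cls(channel_size=64)
    dummy = torch.randn(8, 64, 1000)
    out = norm(dummy)
    print(out.shape)  # torch.Size([8, 64, 1000])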