from functools import partial

import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

EPS = 1e-8


class _LayerNorm(nn.Module):
    """Layer Normalization base class."""

    def __init__(self, channel_size):
        super(_LayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Assumes input of size `[batch, channel, *]`."""
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)


class GlobLN(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
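
# A minimal usage sketch: GlobLN keeps the input shape and pools statistics over
# all non-batch dimensions before applying the learned per-channel gain and bias.
# >>> gln = GlobLN(channel_size=8)
# >>> x = torch.randn(4, 8, 100)           # [batch, chan, time]
# >>> gln(x).shape == x.shape
# True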


class ChanLN(_LayerNorm):
    """Channel-wise Layer Normalization (chanLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
        """
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
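
# A minimal sketch of the contrast with GlobLN: ChanLN takes mean/variance over
# the channel dimension only, so each frame is normalized independently.
# >>> cln = ChanLN(channel_size=8)
# >>> cln(torch.randn(4, 8, 100)).shape
# torch.Size([4, 8, 100])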


class CumLN(_LayerNorm):
    """Cumulative Global layer normalization (cumLN)."""

    def forward(self, x):
        """
        Args:
            x (:class:`torch.Tensor`): Shape `[batch, channels, length]`

        Returns:
            :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`
        """
        batch, chan, spec_len = x.size()
        # Running sums over time of the per-frame totals across channels.
        cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        # Number of elements contributing to the statistics at each time step.
        cnt = torch.arange(
            start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype, device=x.device
        ).view(1, 1, -1)
        cum_mean = cum_sum / cnt
        # Var(X) = E[X^2] - E[X]^2, using the cumulative counts.
        cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
        return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())
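
# A minimal sketch of the causal behaviour: the statistics used at frame t only
# involve frames 0..t, unlike GlobLN, which sees the whole sequence at once.
# >>> cumln = CumLN(channel_size=8)
# >>> y = cumln(torch.randn(2, 8, 50))     # [2, 8, 50], normalized causally
# >>> y.shape
# torch.Size([2, 8, 50])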


class FeatsGlobLN(_LayerNorm):
    """Feature-wise global Layer Normalization (FeatsGlobLN).

    Applies normalization over frames for each channel.
    """

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, time]`

        Returns:
            :class:`torch.Tensor`: fgLN_x `[batch, chan, time]`
        """
        stop = len(x.size())
        dims = list(range(2, stop))
        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
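
# A minimal sketch: FeatsGlobLN pools statistics over the time axis separately
# for each channel, whereas ChanLN pools over channels separately for each frame.
# >>> fgln = FeatsGlobLN(channel_size=8)
# >>> fgln(torch.randn(4, 8, 100)).shape
# torch.Size([4, 8, 100])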


class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D."""

    def _check_input_dim(self, input):
        if input.dim() < 2 or input.dim() > 4:
            raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(input.dim()))


# Aliases.
gLN = GlobLN
fgLN = FeatsGlobLN
cLN = ChanLN
cgLN = CumLN
bN = BatchNorm


def register_norm(custom_norm):
    """Register a custom norm, gettable with `norms.get`.

    Args:
        custom_norm: Custom norm to register.
    """
    if (
        custom_norm.__name__ in globals().keys()
        or custom_norm.__name__.lower() in globals().keys()
    ):
        raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.")
    globals().update({custom_norm.__name__: custom_norm})
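
# A minimal sketch of registering a custom norm; `MyNorm` is a hypothetical
# subclass used only for illustration.
# >>> class MyNorm(_LayerNorm):
# ...     def forward(self, x):
# ...         return self.apply_gain_and_bias(x)
# >>> register_norm(MyNorm)
# >>> get("MyNorm") is MyNorm
# True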


def get(identifier):
    """Returns a norm class from a string. Returns its input if it
    is callable (already a :class:`._LayerNorm` for example).

    Args:
        identifier (str or Callable or None): the norm identifier.

    Returns:
        :class:`._LayerNorm` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret normalization identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret normalization identifier: " + str(identifier))