# Scraped Hugging Face file-viewer header (not code):
# uploaded by DeepLearning101 ("Upload 16 files"), commit b6c45cb (verified), 4.89 kB.
from functools import partial
import torch
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm
# Small constant added to variances before the square root, guarding the
# normalizers below against division by zero.
EPS = 1e-8


class _LayerNorm(nn.Module):
    """Layer Normalization base class.

    Holds a learnable per-channel gain ``gamma`` (initialised to ones) and
    bias ``beta`` (initialised to zeros). Subclasses compute the normalized
    tensor and finish by calling :meth:`apply_gain_and_bias`.
    """

    def __init__(self, channel_size):
        super().__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size))
        self.beta = nn.Parameter(torch.zeros(channel_size))

    def apply_gain_and_bias(self, normed_x):
        """Scale by ``gamma`` and shift by ``beta`` along dimension 1.

        Assumes input of size `[batch, channel, *]`. The channel axis is
        temporarily swapped to the last position so the 1-D parameters
        broadcast against it, then swapped back.
        """
        channels_last = normed_x.transpose(1, -1)
        return (channels_last * self.gamma + self.beta).transpose(1, -1)
class GlobLN(_LayerNorm):
    """Global Layer Normalization (globLN).

    Each example is normalized with a single mean and variance taken over
    every non-batch dimension, then the per-channel gain and bias are applied.
    """

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        reduce_dims = list(range(1, x.dim()))
        centered = x - x.mean(dim=reduce_dims, keepdim=True)
        variance = centered.pow(2).mean(dim=reduce_dims, keepdim=True)
        return self.apply_gain_and_bias(centered / (variance + EPS).sqrt())
class ChanLN(_LayerNorm):
    """Channel-wise Layer Normalization (chanLN).

    Mean and (biased) variance are taken across the channel dimension only,
    independently at every other position.
    """

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
        """
        mu = x.mean(dim=1, keepdim=True)
        sigma = (x.var(dim=1, keepdim=True, unbiased=False) + EPS).sqrt()
        return self.apply_gain_and_bias((x - mu) / sigma)
class CumLN(_LayerNorm):
    """Cumulative Global layer normalization (cumLN).

    At every time step ``k`` the mean and variance are computed over all
    channels and all frames ``0..k`` (running cumulative statistics along the
    last axis), so the output at frame ``k`` does not depend on later frames.
    """

    def forward(self, x):
        """Applies forward pass.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, channels, length]`

        Returns:
            :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`
        """
        _, chan, spec_len = x.size()
        # Running sums of x and x**2 over channels and frames 0..k.
        cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        # Number of elements accumulated up to frame k: chan * (k + 1).
        # Built on x's device so the division also works for non-CPU inputs.
        cnt = torch.arange(
            start=chan,
            end=chan * (spec_len + 1),
            step=chan,
            dtype=x.dtype,
            device=x.device,
        ).view(1, 1, -1)
        cum_mean = cum_sum / cnt
        # Var = E[x^2] - E[x]^2: the running sum of squares must be divided by
        # the count to become E[x^2]. Floating-point cancellation can leave
        # tiny negative values, which would turn into NaN under sqrt.
        cum_var = (cum_pow_sum / cnt - cum_mean.pow(2)).clamp(min=0)
        return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())
class FeatsGlobLN(_LayerNorm):
    """Feature-wise global Layer Normalization (FeatsGlobLN).

    Applies normalization over frames for each channel: mean and (biased)
    variance are taken over every dimension after the channel dimension,
    separately for each (batch, channel) pair.
    """

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, time]`

        Returns:
            :class:`torch.Tensor`: featsglobLN_x `[batch, chan, time]`
        """
        frame_dims = list(range(2, x.dim()))
        mu = x.mean(dim=frame_dims, keepdim=True)
        sigma = (x.var(dim=frame_dims, keepdim=True, unbiased=False) + EPS).sqrt()
        return self.apply_gain_and_bias((x - mu) / sigma)
class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D.

    Accepts inputs with 2, 3 or 4 dimensions, so one class covers both the
    1D and 2D batch-norm cases.
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` has 2, 3 or 4 dimensions."""
        ndim = input.dim()
        if not 2 <= ndim <= 4:
            # The message lists every rank the check accepts (2D included).
            raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(ndim))
# Short aliases for the norm classes; `get` resolves names via the module
# globals, so these strings are accepted as identifiers too.
gLN, fgLN, cLN, cgLN, bN = GlobLN, FeatsGlobLN, ChanLN, CumLN, BatchNorm
def register_norm(custom_norm):
    """Register a custom norm, gettable with `norms.get`.

    The norm is stored in this module's globals under its ``__name__``.

    Args:
        custom_norm: Custom norm to register.

    Raises:
        ValueError: if the name, or its lower-cased form, is already a
            global of this module.
    """
    name = custom_norm.__name__
    namespace = globals()
    if any(candidate in namespace for candidate in (name, name.lower())):
        raise ValueError(f"Norm {name} already exists. Choose another name.")
    namespace[name] = custom_norm
def get(identifier):
    """Returns a norm class from a string. Returns its input if it
    is callable (already a :class:`._LayerNorm` for example).

    Args:
        identifier (str or Callable or None): the norm identifier.

    Returns:
        :class:`._LayerNorm` or None

    Raises:
        ValueError: if ``identifier`` is a string that does not name a
            callable global of this module, or is of any other unsupported
            type.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # Module globals also hold non-norm values (e.g. the float ``EPS``);
        # only callables can be norm classes/factories, so reject the rest.
        if callable(cls):
            return cls
    raise ValueError("Could not interpret normalization identifier: " + str(identifier))