Spaces:
Sleeping
Sleeping
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Normalization modules."""
import typing as tp | |
import torch | |
from torch import nn | |
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.

    Inputs use a channels-first layout ``(batch, *channels, samples)``. The
    trailing samples axis is moved to position 1, so the channel dims become
    the trailing dims that ``nn.LayerNorm`` normalizes over. The result is
    then moved back to the original layout. For the common 3D case
    ``(batch, channels, samples)`` this is a transpose of dims 1 and 2.

    Args:
        normalized_shape: Shape of the channel dim(s) to normalize over,
            forwarded unchanged to ``nn.LayerNorm``.
        **kwargs: Forwarded unchanged to ``nn.LayerNorm``.
    """

    def __init__(
        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
    ):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Layer-normalize ``x`` over its channel dim(s), preserving its layout.

        Args:
            x: Tensor of shape ``(batch, *channels, samples)`` with at least
                3 dims. The trailing ``channels`` dims must match
                ``normalized_shape``.

        Returns:
            Tensor with the same shape and layout as ``x``.

        Raises:
            ValueError: If ``x`` has fewer than 3 dims.
        """
        # Explicit check instead of ``assert`` so the validation survives
        # ``python -O``. With fewer than 3 dims there is no separate channel
        # axis between batch and samples to normalize over.
        if x.ndim < 3:
            raise ValueError(
                "ConvLayerNorm expects input of shape (batch, channels..., samples) "
                f"with at least 3 dims, got shape {tuple(x.shape)}"
            )
        # (batch, *channels, samples) -> (batch, samples, *channels)
        x = x.movedim(-1, 1)
        x = super().forward(x)
        # (batch, samples, *channels) -> (batch, *channels, samples)
        return x.movedim(1, -1)