# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components import NormalizationType, PreNorm


class Passthrough(torch.nn.Module):
    """Parameter-free sublayer that hands back exactly what it was given.

    Used as the ``sublayer`` of ``PreNorm`` in the test below so that the
    tensors PreNorm passes on can be inspected directly.
    """

    def __init__(self) -> None:
        # Nothing to register: only the base nn.Module setup is needed.
        super(Passthrough, self).__init__()

    def forward(self, *tensors):
        """Return all positional arguments, unchanged, as a tuple."""
        echoed = tensors
        return echoed


@pytest.mark.parametrize("normalization", list(NormalizationType))
def test_pre_norm(normalization):
    """Check PreNorm's repeated-input shortcut and its backward pass.

    Runs once for every ``NormalizationType`` member. When the very same
    tensor is passed several times, PreNorm is expected to normalise it once
    and reuse that result, so every output handed to the sublayer must be
    the same object. A backward pass through the first output must then
    reach the input tensor.
    """
    x = torch.rand((3, 3), requires_grad=True)

    wrap = PreNorm(
        d_norm=3, sublayer=Passthrough(), normalization=normalization, use_triton=False
    )
    outputs = wrap(inputs=[x, x, x])

    # NOTE(review): assumes PreNorm forwards one normalised tensor per input
    # to the sublayer; Passthrough echoes them back, so expect three outputs.
    assert len(outputs) == 3
    # Identity (not just equality): the normalisation must not be recomputed.
    assert all(out is outputs[0] for out in outputs)

    # Check the BW pass: gradients must flow back to the input tensor.
    torch.sum(outputs[0]).backward()
    assert x.grad is not None