# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import nullcontext
from typing import Optional
import pytest
import torch
from torch import nn
from xformers.ops import RMSNorm
from .utils import assert_allclose
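
# The ops under test require an sm80+ GPU (compute capability 8.0, i.e. Ampere
# or newer); skip everything else, including CPU-only environments.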
compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability("cuda")
cuda_sm80_only = pytest.mark.skipif(
    compute_capability < (8, 0), reason="requires sm80+"
)

DTYPES = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}
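
# Reference RMS normalization written in plain PyTorch, used as ground truth for
# the xformers RMSNorm op under test: scale each vector by the reciprocal
# root-mean-square over its last dimension, then apply the optional learned weight.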
class RMSNormPytorch(torch.nn.Module):
    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        if include_weight:
            self.weight: Optional[nn.Parameter] = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output.type_as(x)

@cuda_sm80_only
@pytest.mark.parametrize("K", [273, 4100])
@pytest.mark.parametrize("dtype", ["f16", "bf16", "f32"])
def test_forward(K: int, dtype: str):
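    """The xformers RMSNorm output must match the plain-PyTorch reference.

    Checked three ways: with a randomly initialized weight (against the
    reference using the same weight), with the weight reset to ones (which
    must change the output), and with the weight disabled entirely (which
    must match the all-ones case).
    """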
    atol = 1e-8 if dtype == "f32" else 1e-4
    rtol = 1e-5 if dtype == "f32" else 0.01
    torch.manual_seed(1)
    B, M, K = 31, 27, K
    device = torch.device("cuda")
    rms_layer = RMSNorm(K).cuda()
    baseline_layer = RMSNormPytorch(K).cuda()
    x = torch.rand(B, M, K, device=device, dtype=DTYPES[dtype])

    torch.nn.init.normal_(rms_layer.weight)  # type: ignore
    with torch.no_grad():
        x_rms = rms_layer(x)
        assert x_rms.shape == x.shape
        baseline_layer.weight.copy_(rms_layer.weight)  # type: ignore
        baseline = baseline_layer(x)
        assert_allclose(x_rms, baseline, atol=atol, rtol=rtol)

    torch.nn.init.ones_(rms_layer.weight)  # type: ignore
    with torch.no_grad():
        x_rms1 = rms_layer(x)
    assert not torch.allclose(x_rms, x_rms1)

    rms1_layer = RMSNorm(K, include_weight=False)
    with torch.no_grad():
        x_rms_1 = rms1_layer(x)
    assert_allclose(x_rms1, x_rms_1, atol=atol, rtol=rtol)

@cuda_sm80_only
@pytest.mark.parametrize("K", [273, 4100])
@pytest.mark.parametrize("include_weight", [True, False])
@pytest.mark.parametrize("dtype", ["f16", "bf16", "f32"])
def test_increment(K: int, include_weight: bool, dtype: str):
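    """increment_and_forward_ combines an in-place add with the forward pass.

    Expected behaviour checked here: x is updated in place to x + y, y is left
    untouched, and the return value matches rms_layer(x_orig + y_orig).
    """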
    atol = 1e-8 if dtype == "f32" else 1e-4
    rtol = 1e-5 if dtype == "f32" else 0.01
    torch.manual_seed(1)
    B, M, K = 31, 27, K
    device = torch.device("cuda")
    dtype_ = DTYPES[dtype]

    rms_layer = RMSNorm(K, include_weight=include_weight).cuda()
    x_orig = torch.rand(B, M, K, device=device, dtype=dtype_)
    y_orig = torch.rand(B, M, K, device=device, dtype=dtype_)
    x = x_orig.clone()
    y = y_orig.clone()
    if include_weight:
        torch.nn.init.normal_(rms_layer.weight)  # type: ignore

    context = torch.no_grad() if include_weight else nullcontext()
    with context:  # type: ignore
        baseline = rms_layer(x_orig + y_orig)
        out = rms_layer.increment_and_forward_(x, y)

    assert_allclose(out, baseline, atol=atol, rtol=rtol)
    assert_allclose(x, x_orig + y_orig, atol=atol, rtol=rtol)
    assert_allclose(y, y_orig, atol=atol, rtol=rtol)