# xformers/tests/test_swiglu.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy
import functools
import random
from contextlib import nullcontext
from typing import ContextManager, Optional, Sequence, cast
import pytest
import torch
import xformers
import xformers.ops.swiglu_op as xsw
from .utils import disable_tf32
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
if torch.cuda.is_available():
_devices = ["cuda"]
_is_sm80 = torch.cuda.get_device_capability(_devices[0])[0] >= 8
else:
_devices = []
_is_sm80 = False
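# bf16 tests (and most torch.compile tests) only run on sm80 (Ampere) or newer.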
cuda_sm80_only = pytest.mark.skipif(not _is_sm80, reason="requires sm80+")
torch_compile_tests = pytest.mark.skipif(
torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
)
disable_on_rocm = pytest.mark.skipif(
    bool(torch.version.hip), reason="not supported on ROCm"
)
def assert_allclose(
# The output of the tested function
out: torch.Tensor,
# The output of the reference implementation
ref: torch.Tensor,
# The output of the reference implementation in f32
ref32: Optional[torch.Tensor] = None,
msg: str = "failed",
atol: Optional[float] = None,
rtol: Optional[float] = None,
) -> None:
"""
Improved version of
```
assert torch.allclose(out, ref)
```
Except that we provide useful error message, and also compare
to the output of the f32 calculation.
"""
out = out.float()
ref = ref.float()
if atol is None:
atol = 1e-8
if rtol is None:
rtol = 1e-5
assert out.shape == ref.shape
compare_to = ref32 if ref32 is not None else ref
assert out.shape == compare_to.shape
if torch.allclose(out, ref, rtol=rtol, atol=atol) or (
ref32 is not None and torch.allclose(out, ref32, rtol=rtol, atol=atol)
):
return
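    # Mismatch: locate the element with the largest tolerance violation,
    # |out - ref| - atol - rtol * |ref|, so the error message is actionable.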
flatten_diff = ((out - compare_to).abs() - atol - compare_to.abs() * rtol).flatten()
max_pos = flatten_diff.argmax()
if ref32 is not None:
flatten_diff_vsf32 = ((ref - ref32).abs() - atol - ref32.abs() * rtol).flatten()
max_pos_vsf32 = flatten_diff_vsf32.argmax()
assert False, (
f"{msg}: "
f"out={out.flatten()[max_pos]} and ref32={compare_to.flatten()[max_pos]} (diff={flatten_diff[max_pos]} > 0)"
f"/ atol={atol}, rtol={rtol}.\n"
f"NOTE: ref vs ref_f32:\n"
f"ref={ref.flatten()[max_pos_vsf32]} and ref32={ref32.flatten()[max_pos_vsf32]} "
f"(diff={flatten_diff_vsf32[max_pos_vsf32]})"
)
else:
assert False, (
f"{msg}: "
f"out={out.flatten()[max_pos]} and ref={compare_to.flatten()[max_pos]} (diff={flatten_diff[max_pos]} > 0)"
f"/ atol={atol}, rtol={rtol}"
)
def generate_test_shapes():
shapes = [
# Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
# ViT-Giant
(9456, 1536, 2736),
(4440, 1536, 2736),
(4728, 1536, 2736),
# GPT-3 (small)
(2048, 2048, 5632),
# Chinchilla
(2048, 8192, 22016),
]
# Add some random shapes
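    # Dimensions are kept as multiples of 8, presumably to satisfy the
    # alignment requirements of the fused kernels.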
r = random.Random(0)
for _ in range(20):
shapes.append(
(r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
)
return shapes
# Switch between these shape initialisations ...
_test_shapes = list(generate_test_shapes())
_test_shapes_ids = [str(s) for s in _test_shapes]
_dtypes = [torch.float16]
if _is_sm80:
_dtypes += [torch.bfloat16]
_ops: Sequence[xsw.SwiGLUOp] = [xsw.SwiGLUFusedOp, xsw.SwiGLUPackedFusedOp]
FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2, torch.bfloat16: 1e-2}
FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3, torch.bfloat16: 4e-3}
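# Backward tolerances are much looser than forward ones: gradients accumulate
# error across several GEMMs and the SiLU backward, and bf16 carries only
# ~8 bits of mantissa.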
BACKWARD_ATOL = {
torch.float: 3e-4,
torch.half: 0.5,
torch.bfloat16: 4.0, # !!
}
BACKWARD_RTOL = {
torch.float: 2e-3,
torch.half: 1e-2,
torch.bfloat16: 4e-2,
}
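# Constructing a SwiGLU module for the largest test shapes is relatively
# expensive, so the most recent one is cached here and each test works on a
# deepcopy, keeping parameters and gradients isolated between tests.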
@functools.lru_cache(maxsize=1)
def create_module_cached(**kwargs) -> xsw.SwiGLU:
return xsw.SwiGLU(**kwargs)
@disable_tf32
@disable_on_rocm
@pytest.mark.parametrize("autocast", [False, True], ids=["regular", "autocast"])
@pytest.mark.parametrize("op", _ops, ids=[x.NAME for x in _ops])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("bias", [False, True], ids=["nobias", "bias"])
@pytest.mark.parametrize("pack_weights", [False, True], ids=["regular", "packed"])
@pytest.mark.parametrize(
"shape",
_test_shapes,
ids=_test_shapes_ids,
)
def test_forward_backward(
shape,
device,
op,
dtype,
autocast: bool,
pack_weights: bool,
bias: bool,
):
torch.manual_seed(shape[0] * shape[1] * shape[2])
if not op.supports(
xsw.SwiGLUOpDispatch(
device=device,
dtype=dtype,
dtype_autocast_gpu=dtype if autocast and device == "cuda" else None,
packed_weights=pack_weights,
bias_enabled=bias,
)
):
pytest.skip("Not supported by operator")
inp_model_dtype = torch.float if autocast else dtype
x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
module = copy.deepcopy(
create_module_cached(
in_features=shape[1],
hidden_features=shape[2],
bias=bias,
_pack_weights=pack_weights,
)
)
x_f32: Optional[torch.Tensor]
ref_f32: Optional[torch.Tensor]
module_f32: Optional[torch.nn.Module]
if dtype != torch.float:
x_f32, module_f32 = x.to(device).to(torch.float), module.to(device)
x_f32.requires_grad_()
ref_f32 = module_f32(x_f32)
else:
x_f32, module_f32, ref_f32 = None, None, None
x, module = x.to(device).to(inp_model_dtype), module.to(device).to(inp_model_dtype)
x.requires_grad_()
# Forward
cm = cast(
ContextManager,
torch.autocast("cuda", dtype=dtype) if autocast else nullcontext(),
)
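    # `module(x)` is the eager nn.Module reference; `xsw.swiglu` runs the same
    # computation through the fused operator under test, which is roughly
    #   (silu(x @ w1.T + b1) * (x @ w2.T + b2)) @ w3.T + b3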
with cm:
ref = module(x)
out = xsw.swiglu(x, *module._ordered_params(), op=op)
if ref_f32 is None:
ref_f32 = ref
assert_allclose(
out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)
# Backward
grad = torch.randn_like(ref)
def backward_gather_grads(inp, output):
output.backward(grad.to(output.dtype))
grads = {}
for name, param in module.named_parameters():
grads[name] = param.grad.clone()
param.grad = None
grads["x"] = inp.grad.clone()
inp.grad = None
return grads
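    # Gather gradients from the reference path, the fused path and (when the
    # test dtype is not f32) the f32 reference, then compare them name by name.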
grads_ref = backward_gather_grads(x, ref)
grads_out = backward_gather_grads(x, out)
grads_ref32 = (
backward_gather_grads(x_f32, ref_f32) if module_f32 is not None else grads_ref
)
assert list(grads_ref.keys()) == list(grads_out.keys())
for name, gref in grads_ref.items():
gout = grads_out[name]
assert_allclose(
gout,
gref,
grads_ref32.get(name),
f"{name}.grad",
atol=BACKWARD_ATOL[dtype],
rtol=BACKWARD_RTOL[dtype],
)
# Ensure `gout >> atol`, so that the test is meaningful
assert gout.norm(2) > BACKWARD_ATOL[dtype] / BACKWARD_RTOL[dtype]
@torch_compile_tests
@cuda_sm80_only
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("bias", [False, True], ids=["nobias", "bias"])
def test_swiglu_compile(
device,
dtype,
bias: bool,
):
op = xsw.SwiGLUPackedFusedOp
shape = [2048, 2048, 5632]
# Eager
mod = copy.deepcopy(
create_module_cached(
in_features=shape[1],
hidden_features=shape[2],
bias=bias,
_pack_weights=True,
)
)
mod = cast(xsw.SwiGLU, mod)
mod.op = op
mod = mod.to(device).to(dtype)
# Torch compile
mod_c = cast(xsw.SwiGLU, torch.compile(mod, fullgraph=True, dynamic=True))
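    # fullgraph=True makes the test fail if anything in SwiGLU causes a graph
    # break; dynamic=True compiles with symbolic shapes instead of specialising
    # to this particular input size.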
assert mod.w12 is not None
assert mod_c.w12 is not None
x = torch.randn(shape[:2], device=device, dtype=dtype, requires_grad=True)
x_c = x.detach().requires_grad_()
grad = torch.randn(shape[:2], device=device, dtype=dtype, requires_grad=False) * 0.1
# Forward passes
output = mod(x)
output_c = mod_c(x_c)
# Backward passes
output.backward(grad)
output_c.backward(grad)
assert_allclose(
output, output_c, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)
assert x_c.grad is not None and x.grad is not None
assert_allclose(
x_c.grad,
x.grad,
msg="grad",
atol=BACKWARD_ATOL[dtype],
rtol=BACKWARD_RTOL[dtype],
)
assert mod.w12.weight.grad is not None and mod_c.w12.weight.grad is not None
assert_allclose(
mod.w12.weight.grad,
mod_c.w12.weight.grad,
msg="w12.grad",
atol=BACKWARD_ATOL[dtype],
rtol=BACKWARD_RTOL[dtype],
)
assert mod.w3.weight.grad is not None and mod_c.w3.weight.grad is not None
assert_allclose(
mod.w3.weight.grad,
mod_c.w3.weight.grad,
msg="w3.grad",
atol=BACKWARD_ATOL[dtype],
rtol=BACKWARD_RTOL[dtype],
)
if bias:
assert mod.w12.bias.grad is not None and mod_c.w12.bias.grad is not None
assert_allclose(
mod.w12.bias.grad,
mod_c.w12.bias.grad,
msg="w12.bias.grad",
atol=BACKWARD_ATOL[dtype],
rtol=BACKWARD_RTOL[dtype],
)
assert mod.w3.bias.grad is not None and mod_c.w3.bias.grad is not None
assert_allclose(
mod.w3.bias.grad,
mod_c.w3.bias.grad,
msg="w12.bias.grad",
atol=BACKWARD_ATOL[dtype],
rtol=BACKWARD_RTOL[dtype],
)
@disable_tf32
@torch.inference_mode()
@torch_compile_tests
@cuda_sm80_only
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("bias", [False, True], ids=["nobias", "bias"])
def test_dual_gemm_silu_identity_mul_compile(dtype, device, bias) -> None:
N, M, H = (2048, 2048, 5632)
x = torch.randn([N, M], device=device, dtype=dtype, requires_grad=False)
w1 = torch.randn([H, M], device=device, dtype=dtype, requires_grad=False)
w2 = torch.randn([H, M], device=device, dtype=dtype, requires_grad=False)
b1: Optional[torch.Tensor] = None
b2: Optional[torch.Tensor] = None
if bias:
b1 = torch.randn([H], device=device, dtype=dtype, requires_grad=False)
b2 = torch.randn([H], device=device, dtype=dtype, requires_grad=False)
DualGemmSiluOp = xformers.ops.common.get_xformers_operator(
"dual_gemm_silu_identity_mul"
)
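    # Expected semantics (an assumption based on the op name and the SwiGLU
    # forward pass), roughly:
    #   x1 = x @ w1.T + b1
    #   x2 = x @ w2.T + b2
    #   x4 = silu(x1) * x2
    # with all three tensors returned.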
def fn(x):
x1, x2, x4 = DualGemmSiluOp(x, w1, b1, w2, b2)
return [x1, x2, x4]
# Eager
output = fn(x)
# Torch compile
opt_output = torch.compile(fn, fullgraph=True, dynamic=True)(x)
for a, b in zip(output, opt_output):
assert_allclose(
a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)
@disable_tf32
@torch.inference_mode()
@cuda_sm80_only
@torch_compile_tests
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_gemm_fused_operand_sum_compile(dtype, device) -> None:
shape = [2048, 2048, 5632]
x = torch.randn(
[shape[0], shape[2]], device=device, dtype=dtype, requires_grad=False
)
dy = torch.randn(shape[:2], device=device, dtype=dtype, requires_grad=False)
db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
GemmFusedSumOp = xformers.ops.common.get_xformers_operator("gemm_fused_operand_sum")
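    # Expected semantics (an assumption based on the op name and how it is
    # called here): a single kernel computing dw = dy.T @ x together with the
    # column sum db = dy.sum(dim=0), i.e. the weight and bias gradients of a
    # linear layer.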
def fn(x):
GemmFusedSumOp(dy.transpose(-2, -1), x, dw, db)
return [dw, db]
# Eager
output = fn(x)
# Torch compile
opt_output = torch.compile(fn, fullgraph=True, dynamic=True)(x)
for a, b in zip(output, opt_output):
assert_allclose(
a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)
@disable_tf32
@torch.inference_mode()
@torch_compile_tests
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_silu_bw_fused_compile(dtype, device) -> None:
shape = [2048, 2048]
x1 = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
x2 = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
dx4 = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
SiluBWFusedOp = xformers.ops.common.get_xformers_operator("silu_bw_fused")
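    # Expected semantics (an assumption): given the two GEMM outputs x1, x2 and
    # the incoming gradient dx4 of x4 = silu(x1) * x2, the kernel returns the
    # recomputed x4 along with dx1 and dx2 packed into a single tensor.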
def fn(x1, x2, dx4):
dx1dx2, x4 = SiluBWFusedOp(x1, x2, dx4)
return [dx1dx2, x4]
# Eager
with torch.autocast("cuda", dtype=dtype):
output = fn(x1, x2, dx4)
# Torch compile
opt_output = torch.compile(fn, fullgraph=True, dynamic=True)(x1, x2, dx4)
for a, b in zip(output, opt_output):
assert_allclose(
a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)
@disable_tf32
@cuda_only
@cuda_sm80_only
def test_autocast_silu_bw_fused_compile() -> None:
shape = [2048, 2048]
device = "cuda"
dtype = torch.float32
x1 = torch.randn(shape, device=device, dtype=dtype)
x2 = torch.randn(shape, device=device, dtype=dtype)
dx4 = torch.randn(shape, device=device, dtype=dtype)
SiluBWFusedOp = xformers.ops.common.get_xformers_operator("silu_bw_fused")
def fn(x1, x2, dx4):
dx1dx2, x4 = SiluBWFusedOp(x1, x2, dx4)
return [dx1dx2, x4]
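    # The op should produce the same result whether or not an autocast context
    # is active around the call; compare a plain run against an autocast run.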
output = fn(x1, x2, dx4)
# Autocast
with torch.autocast("cuda", dtype=dtype):
output_ac = fn(x1, x2, dx4)
for a, b in zip(output, output_ac):
assert_allclose(
a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
)