# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy
import functools
import random
from contextlib import nullcontext
from typing import ContextManager, Optional, Sequence, cast

import pytest
import torch

import xformers
import xformers.ops.swiglu_op as xsw

from .utils import disable_tf32

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
if torch.cuda.is_available():
    _devices = ["cuda"]
    _is_sm80 = torch.cuda.get_device_capability(_devices[0])[0] >= 8
else:
    _devices = []
    _is_sm80 = False
cuda_sm80_only = pytest.mark.skipif(not _is_sm80, reason="requires sm80+")
torch_compile_tests = pytest.mark.skipif(
    torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
)
disable_on_rocm = pytest.mark.skipif(
    torch.version.hip is not None, reason="not supported on ROCm"
)


def assert_allclose(
    # The output of the tested function
    out: torch.Tensor,
    # The output of the reference implementation
    ref: torch.Tensor,
    # The output of the reference implementation in f32
    ref32: Optional[torch.Tensor] = None,
    msg: str = "failed",
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
) -> None:
""" | |
Improved version of | |
``` | |
assert torch.allclose(out, ref) | |
``` | |
Except that we provide useful error message, and also compare | |
to the output of the f32 calculation. | |
""" | |
    out = out.float()
    ref = ref.float()
    if atol is None:
        atol = 1e-8
    if rtol is None:
        rtol = 1e-5
    assert out.shape == ref.shape
    compare_to = ref32 if ref32 is not None else ref
    assert out.shape == compare_to.shape
    if torch.allclose(out, ref, rtol=rtol, atol=atol) or (
        ref32 is not None and torch.allclose(out, ref32, rtol=rtol, atol=atol)
    ):
        return
    flatten_diff = ((out - compare_to).abs() - atol - compare_to.abs() * rtol).flatten()
    max_pos = flatten_diff.argmax()
    if ref32 is not None:
        flatten_diff_vsf32 = ((ref - ref32).abs() - atol - ref32.abs() * rtol).flatten()
        max_pos_vsf32 = flatten_diff_vsf32.argmax()
        assert False, (
            f"{msg}: "
            f"out={out.flatten()[max_pos]} and ref32={compare_to.flatten()[max_pos]} (diff={flatten_diff[max_pos]} > 0)"
            f"/ atol={atol}, rtol={rtol}.\n"
            f"NOTE: ref vs ref_f32:\n"
            f"ref={ref.flatten()[max_pos_vsf32]} and ref32={ref32.flatten()[max_pos_vsf32]} "
            f"(diff={flatten_diff_vsf32[max_pos_vsf32]})"
        )
    else:
        assert False, (
            f"{msg}: "
            f"out={out.flatten()[max_pos]} and ref={compare_to.flatten()[max_pos]} (diff={flatten_diff[max_pos]} > 0)"
            f"/ atol={atol}, rtol={rtol}"
        )


def generate_test_shapes():
    shapes = [
        # Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
        # ViT-Giant
        (9456, 1536, 2736),
        (4440, 1536, 2736),
        (4728, 1536, 2736),
        # GPT-3 (small)
        (2048, 2048, 5632),
        # Chinchilla
        (2048, 8192, 22016),
    ]
    # Add some random shapes
    r = random.Random(0)
    for _ in range(20):
        shapes.append(
            (r.randint(1, 1000) * 8, r.randint(1, 1000) * 8, r.randint(1, 512) * 8)
        )
    return shapes


# Switch between these shape initialisations ...
_test_shapes = list(generate_test_shapes())
_test_shapes_ids = [str(s) for s in _test_shapes]
_dtypes = [torch.float16]
if _is_sm80:
    _dtypes += [torch.bfloat16]
_ops: Sequence[xsw.SwiGLUOp] = [xsw.SwiGLUFusedOp, xsw.SwiGLUPackedFusedOp]

FORWARD_ATOL = {torch.float: 2e-6, torch.half: 1e-2, torch.bfloat16: 1e-2}
FORWARD_RTOL = {torch.float: 1e-5, torch.half: 4e-3, torch.bfloat16: 4e-3}
BACKWARD_ATOL = {
    torch.float: 3e-4,
    torch.half: 0.5,
    torch.bfloat16: 4.0,  # !!
}
BACKWARD_RTOL = {
    torch.float: 2e-3,
    torch.half: 1e-2,
    torch.bfloat16: 4e-2,
}
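

# The SwiGLU module under test computes (roughly):
#   out = (silu(x @ w1.T + b1) * (x @ w2.T + b2)) @ w3.T + b3
# A single module instance is cached below and deep-copied by each test, so
# every test gets its own parameters and gradients.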
@functools.lru_cache(maxsize=1)
def create_module_cached(**kwargs) -> xsw.SwiGLU:
    return xsw.SwiGLU(**kwargs)


# NOTE: the skip/parametrize markers in this file are assumed from the
# module-level definitions above; the upstream markers and ids may differ.
@disable_tf32
@disable_on_rocm
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("pack_weights", [False, True])
@pytest.mark.parametrize("autocast", [False, True])
@pytest.mark.parametrize("op", _ops)
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(d) for d in _dtypes])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("shape", _test_shapes, ids=_test_shapes_ids)
def test_forward_backward(
    shape,
    device,
    op,
    dtype,
    autocast: bool,
    pack_weights: bool,
    bias: bool,
):
    torch.manual_seed(shape[0] * shape[1] * shape[2])
    if not op.supports(
        xsw.SwiGLUOpDispatch(
            device=device,
            dtype=dtype,
            dtype_autocast_gpu=dtype if autocast and device == "cuda" else None,
            packed_weights=pack_weights,
            bias_enabled=bias,
        )
    ):
        pytest.skip("Not supported by operator")
    inp_model_dtype = torch.float if autocast else dtype
    x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
    module = copy.deepcopy(
        create_module_cached(
            in_features=shape[1],
            hidden_features=shape[2],
            bias=bias,
            _pack_weights=pack_weights,
        )
    )
    x_f32: Optional[torch.Tensor]
    ref_f32: Optional[torch.Tensor]
    module_f32: Optional[torch.nn.Module]
    if dtype != torch.float:
        x_f32, module_f32 = x.to(device).to(torch.float), module.to(device)
        x_f32.requires_grad_()
        ref_f32 = module_f32(x_f32)
    else:
        x_f32, module_f32, ref_f32 = None, None, None

    x, module = x.to(device).to(inp_model_dtype), module.to(device).to(inp_model_dtype)
    x.requires_grad_()

    # Forward
    cm = cast(
        ContextManager,
        torch.autocast("cuda", dtype=dtype) if autocast else nullcontext(),
    )
    with cm:
        ref = module(x)
        out = xsw.swiglu(x, *module._ordered_params(), op=op)
    if ref_f32 is None:
        ref_f32 = ref

    assert_allclose(
        out, ref, ref_f32, "fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
    )

    # Backward
    grad = torch.randn_like(ref)
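
    # Helper: backprop the shared upstream gradient through `output`, then
    # collect and reset the gradients of the module parameters and the input.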
    def backward_gather_grads(inp, output):
        output.backward(grad.to(output.dtype))
        grads = {}
        for name, param in module.named_parameters():
            grads[name] = param.grad.clone()
            param.grad = None
        grads["x"] = inp.grad.clone()
        inp.grad = None
        return grads

    grads_ref = backward_gather_grads(x, ref)
    grads_out = backward_gather_grads(x, out)
    grads_ref32 = (
        backward_gather_grads(x_f32, ref_f32) if module_f32 is not None else grads_ref
    )

    assert list(grads_ref.keys()) == list(grads_out.keys())
    for name, gref in grads_ref.items():
        gout = grads_out[name]
        assert_allclose(
            gout,
            gref,
            grads_ref32.get(name),
            f"{name}.grad",
            atol=BACKWARD_ATOL[dtype],
            rtol=BACKWARD_RTOL[dtype],
        )
        # Ensure `gout >> atol`, so that the test is meaningful
        assert gout.norm(2) > BACKWARD_ATOL[dtype] / BACKWARD_RTOL[dtype]
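

# The compile tests below run the same computation eagerly and under
# torch.compile(fullgraph=True), then check that outputs (and gradients,
# where applicable) match within the usual tolerances.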
@torch_compile_tests
@disable_on_rocm
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(d) for d in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_swiglu_compile(
    device,
    dtype,
    bias: bool,
):
    op = xsw.SwiGLUPackedFusedOp
    shape = [2048, 2048, 5632]
    # Eager
    mod = copy.deepcopy(
        create_module_cached(
            in_features=shape[1],
            hidden_features=shape[2],
            bias=bias,
            _pack_weights=True,
        )
    )
    mod = cast(xsw.SwiGLU, mod)
    mod.op = op
    mod = mod.to(device).to(dtype)
    # Torch compile
    mod_c = cast(xsw.SwiGLU, torch.compile(mod, fullgraph=True, dynamic=True))
    assert mod.w12 is not None
    assert mod_c.w12 is not None

    x = torch.randn(shape[:2], device=device, dtype=dtype, requires_grad=True)
    x_c = x.detach().requires_grad_()
    grad = torch.randn(shape[:2], device=device, dtype=dtype, requires_grad=False) * 0.1

    # Forward passes
    output = mod(x)
    output_c = mod_c(x_c)

    # Backward passes
    output.backward(grad)
    output_c.backward(grad)

    assert_allclose(
        output, output_c, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
    )
    assert x_c.grad is not None and x.grad is not None
    assert_allclose(
        x_c.grad,
        x.grad,
        msg="grad",
        atol=BACKWARD_ATOL[dtype],
        rtol=BACKWARD_RTOL[dtype],
    )
    assert mod.w12.weight.grad is not None and mod_c.w12.weight.grad is not None
    assert_allclose(
        mod.w12.weight.grad,
        mod_c.w12.weight.grad,
        msg="w12.grad",
        atol=BACKWARD_ATOL[dtype],
        rtol=BACKWARD_RTOL[dtype],
    )
    assert mod.w3.weight.grad is not None and mod_c.w3.weight.grad is not None
    assert_allclose(
        mod.w3.weight.grad,
        mod_c.w3.weight.grad,
        msg="w3.grad",
        atol=BACKWARD_ATOL[dtype],
        rtol=BACKWARD_RTOL[dtype],
    )
    if bias:
        assert mod.w12.bias.grad is not None and mod_c.w12.bias.grad is not None
        assert_allclose(
            mod.w12.bias.grad,
            mod_c.w12.bias.grad,
            msg="w12.bias.grad",
            atol=BACKWARD_ATOL[dtype],
            rtol=BACKWARD_RTOL[dtype],
        )
        assert mod.w3.bias.grad is not None and mod_c.w3.bias.grad is not None
        assert_allclose(
            mod.w3.bias.grad,
            mod_c.w3.bias.grad,
            msg="w3.bias.grad",
            atol=BACKWARD_ATOL[dtype],
            rtol=BACKWARD_RTOL[dtype],
        )


@torch_compile_tests
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(d) for d in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_dual_gemm_silu_identity_mul_compile(dtype, device, bias) -> None:
    N, M, H = (2048, 2048, 5632)
    x = torch.randn([N, M], device=device, dtype=dtype, requires_grad=False)
    w1 = torch.randn([H, M], device=device, dtype=dtype, requires_grad=False)
    w2 = torch.randn([H, M], device=device, dtype=dtype, requires_grad=False)
    b1: Optional[torch.Tensor] = None
    b2: Optional[torch.Tensor] = None
    if bias:
        b1 = torch.randn([H], device=device, dtype=dtype, requires_grad=False)
        b2 = torch.randn([H], device=device, dtype=dtype, requires_grad=False)

    DualGemmSiluOp = xformers.ops.common.get_xformers_operator(
        "dual_gemm_silu_identity_mul"
    )
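
    # The fused op appears to perform both GEMMs and the SiLU gating in a
    # single kernel: x1 = x @ w1.T (+ b1), x2 = x @ w2.T (+ b2), and
    # x4 = silu(x1) * x2 (the gated product).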
    def fn(x):
        x1, x2, x4 = DualGemmSiluOp(x, w1, b1, w2, b2)
        return [x1, x2, x4]

    # Eager
    output = fn(x)
    # Torch compile
    opt_output = torch.compile(fn, fullgraph=True, dynamic=True)(x)
    for a, b in zip(output, opt_output):
        assert_allclose(
            a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
        )


@torch_compile_tests
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(d) for d in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_gemm_fused_operand_sum_compile(dtype, device) -> None:
    shape = [2048, 2048, 5632]
    x = torch.randn(
        [shape[0], shape[2]], device=device, dtype=dtype, requires_grad=False
    )
    dy = torch.randn(shape[:2], device=device, dtype=dtype, requires_grad=False)
    db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
    dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)

    GemmFusedSumOp = xformers.ops.common.get_xformers_operator("gemm_fused_operand_sum")
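
    # The op writes into the preallocated dw/db buffers; judging by the shapes,
    # it fuses the weight gradient dw = dy.T @ x with the reduction of dy
    # (i.e. the bias gradient db) in one pass.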
    def fn(x):
        GemmFusedSumOp(dy.transpose(-2, -1), x, dw, db)
        return [dw, db]

    # Eager
    output = fn(x)
    # Torch compile
    opt_output = torch.compile(fn, fullgraph=True, dynamic=True)(x)
    for a, b in zip(output, opt_output):
        assert_allclose(
            a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
        )


@torch_compile_tests
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(d) for d in _dtypes])
@pytest.mark.parametrize("device", _devices)
def test_silu_bw_fused_compile(dtype, device) -> None:
    shape = [2048, 2048]
    x1 = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
    x2 = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)
    dx4 = torch.randn(shape, device=device, dtype=dtype, requires_grad=False)

    SiluBWFusedOp = xformers.ops.common.get_xformers_operator("silu_bw_fused")
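
    # Fused backward of the SiLU gating: given the two forward GEMM outputs
    # (x1, x2) and the upstream gradient dx4, it appears to return the
    # gradients w.r.t. x1 and x2 packed into one tensor (dx1dx2) plus the
    # recomputed gated activation x4.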
    def fn(x1, x2, dx4):
        dx1dx2, x4 = SiluBWFusedOp(x1, x2, dx4)
        return [dx1dx2, x4]

    # Eager
    with torch.autocast("cuda", dtype=dtype):
        output = fn(x1, x2, dx4)
    # Torch compile
    opt_output = torch.compile(fn, fullgraph=True, dynamic=True)(x1, x2, dx4)
    for a, b in zip(output, opt_output):
        assert_allclose(
            a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
        )


@cuda_only
def test_autocast_silu_bw_fused_compile() -> None:
    shape = [2048, 2048]
    device = "cuda"
    dtype = torch.float32
    x1 = torch.randn(shape, device=device, dtype=dtype)
    x2 = torch.randn(shape, device=device, dtype=dtype)
    dx4 = torch.randn(shape, device=device, dtype=dtype)

    SiluBWFusedOp = xformers.ops.common.get_xformers_operator("silu_bw_fused")

    def fn(x1, x2, dx4):
        dx1dx2, x4 = SiluBWFusedOp(x1, x2, dx4)
        return [dx1dx2, x4]

    output = fn(x1, x2, dx4)
    # Autocast
    with torch.autocast("cuda", dtype=dtype):
        output_ac = fn(x1, x2, dx4)
    for a, b in zip(output, output_ac):
        assert_allclose(
            a, b, msg="fw", atol=FORWARD_ATOL[dtype], rtol=FORWARD_RTOL[dtype]
        )