# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import random
from typing import cast

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import xformers  # noqa: F401
import xformers.ops as xops
import xformers.ops.sp24 as sp24

from .utils import assert_allclose

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability("cuda")
torch_compile_tests = pytest.mark.skipif(
    torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
)
requires_sp24 = pytest.mark.skipif(compute_capability < (8, 0), reason="requires sm80+")
requires_sp24_gemm = pytest.mark.skipif(
    compute_capability != (8, 0), reason="requires sm80"
)
parametrize_dtype = pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16], ids=["f16", "bf16"]
)
parametrize_backend = pytest.mark.parametrize(
    "backend",
    [sp24.BACKEND_CUTLASS, sp24.BACKEND_CUSPARSELT]
    if sp24._has_cusparseLt()
    else [sp24.BACKEND_CUTLASS],
)

atol_rtol_kw = {
    torch.float16: {
        "rtol": 2e-3,
        "atol": 1e-4,
    },
    torch.bfloat16: {
        "rtol": 1e-1,
        "atol": 1e-1,
    },
}


def test_sparse24_largest_mask_2d() -> None:
    inp = torch.tensor(
        [[4, 3, 2, 1], [0, 0, 0.5, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
        device="cuda",
        dtype=torch.float16,
    )
    out = torch.ops.xformers.sparse24_largest_mask_2d(inp)
    assert out.int().tolist() == [
        [1, 1, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 1, 1],
        [1, 0, 0, 1],
    ]
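

# Illustration only (not part of the original test suite): the expected mask
# above keeps exactly 2 of every 4 values along both rows and columns of each
# 4x4 tile, which appears to be the invariant `sparse24_largest_mask_2d`
# targets. A minimal sketch of a checker for that property, assuming `mask`
# is a boolean/0-1 tensor whose dimensions are multiples of 4:
def _check_24_pattern_both_ways(mask: torch.Tensor) -> None:
    m = mask.int()
    assert m.shape[0] % 4 == 0 and m.shape[1] % 4 == 0
    # At most 2 kept values in every contiguous group of 4 along rows...
    assert bool((m.reshape(m.shape[0], -1, 4).sum(-1) <= 2).all())
    # ...and along columns (e.g. the expected `out` above satisfies both)
    mt = m.t().contiguous()
    assert bool((mt.reshape(mt.shape[0], -1, 4).sum(-1) <= 2).all())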


def test_autocast(dtype, backend: str) -> None:
    N = 128
    inp = torch.randn([N, N], dtype=torch.float32, device="cuda")
    W = torch.randn([N, N], dtype=torch.float32, device="cuda")
    sInp = sp24.sparsify24(inp.to(dtype=dtype), backend=backend)
    y = sInp @ W.to(dtype=dtype)
    with torch.autocast("cuda", dtype=dtype):
        sInp_ac = sp24.sparsify24(inp, backend=backend)
        y_ac = sInp_ac @ W
    assert_allclose(
        sInp._sp24_to_dense(),
        sInp_ac._sp24_to_dense(),
        "sparse24",
        **atol_rtol_kw[dtype],
    )
    assert_allclose(y, y_ac, "gemm", **atol_rtol_kw[dtype])


def test_sparse24_causal1122(dtype) -> None:
    inp = torch.tensor(
        [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
        device="cuda",
        dtype=dtype,
    )
    inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
    sInp = sp24.sparsify24(inp, algo="causal1122")
    mask = sInp._sp24_to_dense() / inp
    assert mask[:4, :4].int().tolist() == [
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 1],
        [1, 0, 0, 1],
    ]


def test_sparse24_largest_abs_values_greedy(dtype, backend) -> None:
    inp = torch.tensor(
        [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
        device="cuda",
        dtype=dtype,
    )
    inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
    sInp = sp24.sparsify24(inp, algo="largest_abs_values_greedy", backend=backend)
    mask = sInp._sp24_to_dense() / inp
    assert mask[:4, :4].int().tolist() == [
        [1, 1, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 1, 1],
        [1, 0, 0, 1],
    ]


def test_sparse24_largest_mask_2d_notaligned(dtype) -> None:
    inp = torch.randn([5, 5], device="cuda", dtype=dtype)
    with pytest.raises(RuntimeError):
        torch.ops.xformers.sparse24_largest_mask_2d(inp)


def test_sparse24_largest_mask_2d_big(dtype) -> None:
    inp = torch.randn([2048, 2048], device="cuda", dtype=dtype)
    torch.ops.xformers.sparse24_largest_mask_2d(inp)


def create_random_mask(shape) -> torch.Tensor:
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask


def test_detach_requires_grad() -> None:
    x = torch.randn([128, 64], device="cuda", dtype=torch.float16, requires_grad=True)
    xs = sp24.sparsify24(x)
    assert xs.requires_grad
    # `detach` behavior
    xs2 = xs.detach()
    assert not xs2.requires_grad
    assert not (xs2 * 2).requires_grad
    xs2.requires_grad_(True)
    assert xs2.requires_grad
    ys = xs2 * 2
    assert ys.requires_grad
    ys.backward(ys)


def test_detach2() -> None:
    x = torch.randn([128, 64], device="cuda", dtype=torch.float16, requires_grad=False)
    assert not sp24.sparsify24(x).requires_grad
    x.requires_grad_(True)
    xs = sp24.sparsify24(x)
    assert xs.requires_grad
    xs2 = xs.detach()
    xs2.requires_grad_(True)
    xs3 = xs2 * 2
    assert xs3.requires_grad
    xs3.backward(xs3)
    assert xs2.grad is not None
    assert x.grad is None


def test_meta_pack_and_reorder() -> None:
    mask = create_random_mask([32, 64])
    # Test a specific line with a known pattern
    line = 3
    mask[line, :16] = torch.tensor(
        [
            False, True, True, False,  # 1 << 0 | 2 << 2
            True, True, False, False,  # 0 << 4 | 1 << 6
            True, False, False, True,  # 0 << 8 | 3 << 10
            False, True, True, False,  # 1 << 12 | 2 << 14
        ],
        dtype=torch.bool,
    )
    packed = torch.ops.xformers._sparse24_pack_mask(mask)
    assert packed.shape == (mask.shape[0], mask.shape[1] // 16)
    # cast int16 -> uint16
    value_packed = (packed[line, 0].item() + (1 << 16)) % (1 << 16)
    expected_value = (
        1 << 0 | 2 << 2 | 0 << 4 | 1 << 6 | 0 << 8 | 3 << 10 | 1 << 12 | 2 << 14
    )
    assert value_packed == expected_value
    meta_reordered = torch.ops.xformers._sparse24_reorder_meta(packed)
    assert meta_reordered.ndim == 3
    assert meta_reordered.shape[0] == packed.shape[0]
    assert (meta_reordered[0, 0, 0] == packed[0, 0]).item()
    assert (meta_reordered[0, 1, 0] == packed[8, 0]).item()
    assert (meta_reordered[1, 0, 0] == packed[0, 1]).item()
    assert (meta_reordered[1, 1, 0] == packed[8, 1]).item()
    assert (meta_reordered[2, 0, 0] == packed[16, 0]).item()
    assert (meta_reordered[2, 1, 0] == packed[24, 0]).item()
    # second column
    assert (meta_reordered[0, 0, 1] == packed[0, 2]).item()
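

# A hypothetical pure-Python reference for the 16-bit packing checked above
# (it mirrors the `expected_value` computation in the test and is not part of
# the xformers API): every group of 4 booleans must keep exactly 2 values and
# contributes their 2-bit column indices (4 bits per group), with the lowest
# group stored in the least-significant bits.
def _pack_16_mask_reference(bits) -> int:
    assert len(bits) == 16
    value = 0
    for group in range(4):
        kept = [i for i in range(4) if bits[4 * group + i]]
        assert len(kept) == 2, "each group of 4 must keep exactly 2 values"
        value |= kept[0] << (4 * group)
        value |= kept[1] << (4 * group + 2)
    return value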


def test_pack_tensor_according_to_mask() -> None:
    mask = create_random_mask([32, 64])
    # Test a specific line with a known pattern
    line = 3
    line_pattern = [
        False, True, True, False,
        True, True, False, False,
        True, False, False, True,
        False, True, True, False,
    ]
    mask[line, :16] = torch.tensor(line_pattern, dtype=torch.bool)
    packed = torch.ops.xformers._sparse24_pack_mask(mask)
    reordered = torch.ops.xformers._sparse24_reorder_meta(packed)
    a_full = torch.randn(mask.shape, dtype=torch.float16)
    a_packed = torch.ops.xformers._sparse24_pack_tensor_according_to_mask(
        a_full, reordered
    )
    line_full = a_full[line, :16].tolist()
    line_packed = a_packed[line, :8].tolist()
    line_packed_expected = [
        value for index, value in enumerate(line_full) if line_pattern[index]
    ]
    assert line_packed == line_packed_expected


def test_sp24_gemm(dtype) -> None:
    M, N, K = 32, 32, 64
    a = torch.randn([M, K], device="cuda", dtype=dtype)
    b = torch.randn([K, N], device="cuda", dtype=dtype)
    mask = create_random_mask([M, K])
    mask_packed = torch.ops.xformers._sparse24_pack_mask(mask)
    mask_reordered = torch.ops.xformers._sparse24_reorder_meta(mask_packed)
    packed_a = torch.ops.xformers._sparse24_pack_tensor_according_to_mask(
        a.cpu(), mask_reordered
    )
    packed_a = packed_a.cuda()
    mask_reordered = mask_reordered.cuda()
    mask = mask.to(dtype).cuda()
    masked_a = a * mask
    ref_out = masked_a @ b
    sp24_out = torch.ops.xformers._sparse24_gemm(packed_a, b, mask_reordered)
    assert_allclose(ref_out, sp24_out, msg="sp24 GEMM", **atol_rtol_kw[dtype])


def test_pack_meta_shuffle(transpose: bool) -> None:
    local_meta = torch.zeros([4, 8, 8], dtype=torch.int64, device="cuda")
    local_meta[:2, :2] = torch.randint(
        0, 256, size=(2, 2, 8), dtype=torch.int64, device="cuda"
    )
    final_meta_tensor = torch.ops.xformers._sparse24_meta_shuffle_test(
        local_meta, transpose
    )
    assert final_meta_tensor[2:, 2:].abs().max().item() == 0
    final_meta = final_meta_tensor.tolist()

    def pack(line):
        if transpose:
            return int(
                local_meta[0, 0, line]
                | (local_meta[1, 0, line] << 8)
                | (local_meta[0, 1, line] << 16)
                | (local_meta[1, 1, line] << 24)
            )
        else:
            return int(
                local_meta[0, 0, line]
                | (local_meta[0, 1, line] << 8)
                | (local_meta[1, 0, line] << 16)
                | (local_meta[1, 1, line] << 24)
            )

    def meta_str(m):
        return " ".join(f"0x{mm:02x}" for mm in m.tolist())

    def expect_match(i, j, line):
        value = final_meta[i][j][0]
        expected = pack(line)
        assert (
            value == expected
        ), f"""value: 0x{value:02x} (expected: 0x{expected:02x})
{meta_str(local_meta[0, 0, :4])} (T0) |||| {meta_str(local_meta[0, 1, :4])} (T4)
{meta_str(local_meta[1, 0, :4])} (T1) |||| {meta_str(local_meta[1, 1, :4])} (T5)
"""

    expect_match(0, 0, 0)  # T0
    if transpose:
        expect_match(1, 0, 1)  # T1
        expect_match(0, 1, 2)  # T4
    else:
        expect_match(0, 1, 1)  # T4
        expect_match(1, 0, 2)  # T1
    expect_match(1, 1, 3)  # T5
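

# Note on the layout exercised above, inferred from `pack` and the T0/T1/T4/T5
# labels in the assertion message (not an official spec): each final 32-bit
# metadata word gathers four 8-bit values, and reading from least- to
# most-significant byte the order is T0, T4, T1, T5 in the non-transposed case
# and T0, T1, T4, T5 when `transpose` is set.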


def test_pack_both_ways_meta_correctness(dtype, backend) -> None:
    M, N = 128, 256
    # Construct x to make sure we always have exactly 8 elements per 4x4 tile
    a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
    a = a.repeat(M // 8, N // 8)
    assert a.shape == (M, N)
    a = a.cuda().to(dtype)
    b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)
    a_sparse = sp24.sparsify24(a, backend=backend)
    mask_dense = torch.ops.xformers.sparse24_largest_mask_2d(a)
    if backend == sp24.BACKEND_CUTLASS:
        assert isinstance(a_sparse, sp24.Sparse24TensorCutlass)
        mask_packed = torch.ops.xformers._sparse24_pack_mask(mask_dense.cpu().bool())
        mask_reordered = torch.ops.xformers._sparse24_reorder_meta(mask_packed).cuda()
        assert torch.allclose(a_sparse.meta.view(torch.short), mask_reordered)
    ref_gemm = (mask_dense * a) @ b
    pack_gemm = a_sparse @ b
    assert_allclose(ref_gemm, pack_gemm, msg="sp24 GEMM", **atol_rtol_kw[dtype])


def test_pack_both_ways_id(dtype) -> None:
    N = 512
    torch.manual_seed(0)
    a = torch.randn([N, N], dtype=dtype, device="cuda")
    b = torch.eye(N, dtype=dtype, device="cuda")
    packed, meta, packed_t, meta_t = torch.ops.xformers.sparse24_sparsify_both_ways(a)[
        :4
    ]
    # Heuristic to ensure we pack the same values
    assert torch.allclose(
        packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
    )
    mask_dense = torch.ops.xformers.sparse24_largest_mask_2d(a.to(dtype))
    ref_gemm = mask_dense * a
    # Test A@B
    pack_gemm = torch.ops.xformers._sparse24_gemm(packed, b, meta)
    max_diff = (ref_gemm - pack_gemm).abs().argmax()
    assert torch.allclose(
        ref_gemm, pack_gemm
    ), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
    # Test A.t@B
    pack_gemm = torch.ops.xformers._sparse24_gemm(packed_t, b, meta_t)
    pack_gemm = pack_gemm.transpose(0, 1)
    max_diff = (ref_gemm - pack_gemm).abs().argmax()
    assert torch.allclose(
        ref_gemm, pack_gemm
    ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"


def test_pack_both_ways_edge_case1(dtype) -> None:
    # In this case, the heuristic will keep 7 values out of 16
    # instead of 8. Let's see how the kernel handles this.
    quad = torch.tensor(
        [
            [2, -1, -2, -3],  # Should be packed as `2 <null>`
            [-1, 8, -1, 6],
            [-1, -1, 4, 5],
            [-1, 3, 7, -1],
        ],
        dtype=dtype,
        device="cuda",
    )
    a = torch.randn([32, 64], dtype=dtype, device="cuda")
    a[:4, :4] = quad
    packed, meta, packed_t, meta_t = torch.ops.xformers.sparse24_sparsify_both_ways(a)[
        :4
    ]
    # Check first line in A
    assert packed[0, 0].item() == 2
    assert packed[0, 1].item() == 0
    # And first column in A.t
    assert packed_t[0, 0].item() == 2
    assert packed_t[0, 1].item() == 0


def test_sp24_apply(dtype) -> None:
    M, N = 256, 1024
    x = torch.randn([M, N], dtype=dtype, device="cuda")
    (
        packed,
        meta,
        packed_t,
        meta_t,
        threads_masks,
    ) = torch.ops.xformers.sparse24_sparsify_both_ways(x)
    packed2, _, packed_t2, _ = torch.ops.xformers.sparse24_apply(x, threads_masks)
    assert torch.allclose(packed, packed2)
    assert torch.allclose(packed_t, packed_t2)


def test_sp24_api_different_pattern(dtype) -> None:
    M, N = 256, 256
    x = torch.randn([M, N], dtype=dtype, device="cuda")
    y = torch.randn([M, N], dtype=dtype, device="cuda")
    sx = sp24.sparsify24(x)
    sy = sp24.sparsify24(y)
    # Can't add with different sparsity pattern
    with pytest.raises(ValueError):
        sx + sy
    # Ok, same sparsity pattern
    assert isinstance(sx + sx, sp24.Sparse24Tensor)
    # Ok, sharing sparsity pattern of x
    sy2 = sp24.sparsify24_like(y, sx)
    assert isinstance(sx + sy2, sp24.Sparse24Tensor)


def test_sp24_api_different_pattern_transposed(dtype) -> None:
    N = 256
    x = torch.randn([N, N], dtype=dtype, device="cuda")
    sx = sp24.sparsify24(x, backend=sp24.BACKEND_CUTLASS)
    sxt = sx.t()
    assert isinstance(sxt, sp24.Sparse24Tensor)
    # Can't add with different sparsity pattern
    with pytest.raises(ValueError):
        sx + sxt
    # But this should work
    sx + sxt.t()
    # And we should be able to sparsify with transposed pattern
    sxt2 = sp24.sparsify24_like(x.t(), sxt)
    assert torch.allclose(sxt2.packed, sxt.packed)
    assert torch.allclose(sxt2.packed_t, sxt.packed_t)


def test_sp24_transpose_invariant(dtype, backend) -> None:
    M, N = 128, 256
    torch.manual_seed(0)
    r = random.Random(0)

    def gen4x4():
        # Create a 4x4 tile that can be 2:4-sparsified perfectly
        values = [
            [1, 1, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
            [1, 0, 0, 1],
        ]
        c1, c2 = r.sample([0, 1, 2, 3], 2)
        r1, r2 = r.sample([0, 1, 2, 3], 2)
        values[r1], values[r2] = values[r2], values[r1]
        for i in range(4):
            values[i][c1], values[i][c2] = values[i][c2], values[i][c1]
        return values

    a = torch.zeros([M, N], device="cuda", dtype=torch.float16)
    assert M % 4 == 0 and N % 4 == 0
    for m in range(0, M, 4):
        for n in range(0, N, 4):
            a[m : m + 4, n : n + 4] = torch.tensor(
                gen4x4(), device="cuda", dtype=torch.float16
            )
    a = a * torch.randn_like(a).abs()
    # Sparsify `a` and `a.t()`
    a_s = sp24.sparsify24(a, backend=backend)
    a_t_s = sp24.sparsify24(a.t().contiguous(), backend=backend)
    assert_allclose(a_s._sp24_to_dense(), a)
    assert_allclose(a_t_s.t()._sp24_to_dense(), a)  # type: ignore
    assert_allclose(a_t_s._sp24_to_dense().t(), a)


def test_sp24_matmuls(dtype) -> None:
    M, N, K = 64, 256, 1024
    a = torch.randn([M, K], device="cuda", dtype=dtype)
    b = torch.randn([K, N], device="cuda", dtype=dtype)
    a_m = torch.ops.xformers.sparse24_largest_mask_2d(a)
    b_m = torch.ops.xformers.sparse24_largest_mask_2d(b)
    a_s = sp24.sparsify24(a)
    b_s = sp24.sparsify24(b)
    assert_allclose(a_s @ b, (a * a_m) @ b, msg="sp@dense", **atol_rtol_kw[dtype])
    assert_allclose(a @ b_s, a @ (b * b_m), msg="dense@sp", **atol_rtol_kw[dtype])
    assert_allclose(
        a @ a_s.t(), a @ (a * a_m).t(), msg="dense@sp.t", **atol_rtol_kw[dtype]
    )
    assert_allclose(
        a_s.t() @ a, (a * a_m).t() @ a, msg="sp.t@dense", **atol_rtol_kw[dtype]
    )


def test_sp24_matmuls_mat_vec() -> None:
    a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
    b = torch.randn([128], device="cuda", dtype=torch.float16)
    a_m = torch.ops.xformers.sparse24_largest_mask_2d(a)
    a_s = sp24.sparsify24(a)
    with pytest.raises(NotImplementedError):
        assert_allclose(a_s @ b, (a * a_m) @ b, msg="sp@dense", **atol_rtol_kw[a.dtype])


def test_sp24_matmuls_bmm() -> None:
    a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
    b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
    a_m = torch.ops.xformers.sparse24_largest_mask_2d(a)
    a_s = sp24.sparsify24(a)
    with pytest.raises(NotImplementedError):
        assert_allclose(a_s @ b, (a * a_m) @ b, msg="sp@dense", **atol_rtol_kw[a.dtype])


def sparsify24_dense(tensor: torch.Tensor):
    m = torch.ops.xformers.sparse24_largest_mask_2d(tensor)
    return m * tensor


def test_sp24_api_mlp_act24_correctness(dtype, act) -> None:
    B, in_ft, hid_ft, out_ft = 256, 2048, 6144, 2048
    torch.manual_seed(0)
    x = torch.randn([B, in_ft], dtype=dtype, device="cuda", requires_grad=True)
    w1 = (
        torch.randn([in_ft, hid_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    w2 = (
        torch.randn([hid_ft, out_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    grad = (
        torch.randn([B, out_ft], dtype=dtype, device="cuda", requires_grad=False) * 0.1
    )
    w1.requires_grad_(True)
    w2.requires_grad_(True)
    params_with_grads = [x, w1, w2]
    # Run baseline
    x1 = x @ w1
    x1 = sparsify24_dense(x1)
    x1 = act(x1)
    out = x1 @ w2
    out.backward(grad)
    grads_ref = [t.grad for t in params_with_grads]
    for t in params_with_grads:
        t.grad = None
    # Run with sparsity
    x1 = x @ w1
    x1 = sp24.sparsify24(x1)
    x1 = act(x1)
    out = x1 @ w2
    out.backward(grad)
    for grad_name, grad_ref, grad_calc in zip(
        ["x", "w1", "w2"], grads_ref, [t.grad for t in params_with_grads]
    ):
        assert grad_calc is not None, grad_name
        assert grad_ref is not None, grad_name
        assert_allclose(grad_calc, grad_ref, msg=grad_name, **atol_rtol_kw[dtype])


def test_sp24_api_swiglu_correctness(dtype) -> None:
    B, in_ft, hid_ft, out_ft = 256, 2048, 6144 // 2, 2048
    torch.manual_seed(0)
    x = torch.randn([B, in_ft], dtype=dtype, device="cuda", requires_grad=True)
    w1 = (
        torch.randn([in_ft, hid_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    w2 = (
        torch.randn([in_ft, hid_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    w3 = (
        torch.randn([hid_ft, out_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    grad = (
        torch.randn([B, out_ft], dtype=dtype, device="cuda", requires_grad=False) * 0.1
    )
    w1.requires_grad_(True)
    w2.requires_grad_(True)
    w3.requires_grad_(True)
    params_with_grads = [x, w1, w2, w3]
    # Run baseline
    x1 = x @ w1
    x2 = x @ w2
    x1s = sparsify24_dense(F.silu(x1))
    hid = x1s * x2
    out = hid @ w3
    out.backward(grad)
    grads_ref = [t.grad for t in params_with_grads]
    for t in params_with_grads:
        t.grad = None
    # Run with sparsity
    x1 = x @ w1
    x2 = x @ w2
    x1s = sp24.sparsify24(F.silu(x1))
    hid = x1s * x2
    out = hid @ w3
    out.backward(grad)
    for grad_name, grad_ref, grad_calc in zip(
        ["x", "w1", "w2", "w3"], grads_ref, [t.grad for t in params_with_grads]
    ):
        assert grad_calc is not None, grad_name
        assert grad_ref is not None, grad_name
        assert_allclose(grad_calc, grad_ref, msg=grad_name, **atol_rtol_kw[dtype])


def test_not_aligned(dtype, M):
    N, K = 64, 128
    A = torch.randn([M, K], dtype=dtype, device="cuda")
    B = torch.randn([K, N], dtype=dtype, device="cuda")
    As = sp24.sparsify24(A)
    A = As._sp24_to_dense()
    assert tuple(A.shape) == (M, K), A.shape
    assert_allclose(As @ B, A @ B, msg="not aligned", **atol_rtol_kw[dtype])


def test_sparsify24_like_dense(dtype, input_rowmajor, backend):
    M, N = 128, 256
    if input_rowmajor:
        x = torch.randn([M, N], dtype=dtype, device="cuda")
    else:
        x = torch.randn([N, M], dtype=dtype, device="cuda").t()
    sx = sp24.sparsify24(x.contiguous(), backend=backend)
    sx_like = sp24.sparsify24_like(x, pattern=sx, backend="dense")
    assert_allclose(
        sx_like, sx._sp24_to_dense(), msg="sp24_like", **atol_rtol_kw[dtype]
    )


def test_sparsify24_weights(dtype, backend):
    x = torch.randn([128, 512], dtype=dtype, device="cuda", requires_grad=True)
    w = torch.randn([1024, 512], dtype=dtype, device="cuda", requires_grad=True)
    flat_w = w.flatten()  # FSDP-like processing
    w = flat_w.reshape(w.shape)
    sw = sp24.sparsify24(w, gradient="24dense", backend=backend)
    y = x @ sw.t()
    y.backward(y)


class LinearW24(torch.nn.Linear):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input_shape = input.shape
        input = input.flatten(end_dim=-2)
        dim0 = input.shape[0]
        if dim0 % 8 != 0:
            # NOTE: This should be torch-compiled away
            input = F.pad(input, [0, 0, 0, -dim0 % 8])
        w_sparse = xops.sparsify24(
            self.weight,
            gradient="24dense",
            backend="cusparselt",
        )
        return F.linear(input, w_sparse, self.bias)[:dim0].unflatten(
            dim=0, sizes=input_shape[:-1]
        )
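

# Note on the padding in `LinearW24.forward` above (worked example; the
# "multiple of 8" requirement is also what `test_wrong_alignment_error_message`
# below checks): for a batch of 65 rows, `-65 % 8 == 7`, so the input is padded
# to 72 rows before the sparse GEMM and the output is sliced back to 65 rows.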


# XXX: This is needed to avoid a CUDA internal error
# See the issue here:
# https://github.com/pytorch/pytorch/issues/113776
def _workaround_cusparselt_internal_error() -> None:
    x0 = torch.randn([128, 128], device="cuda", dtype=torch.float16, requires_grad=True)
    m = LinearW24(128, 128, bias=False).cuda().to(torch.float16)
    out = m(x0)
    out.backward(out)


def test_linearw24(dtype, bias: bool, aligned: bool, amp: bool) -> None:
    _workaround_cusparselt_internal_error()
    B, ft_in, ft_out = 64, 128, 256
    if not aligned:
        B = 65
    model_dtype = torch.float32 if amp else dtype
    x = torch.randn([B, ft_in], device="cuda", dtype=model_dtype, requires_grad=True)
    grad = torch.randn([B, ft_out], device="cuda", dtype=model_dtype)
    m = torch.nn.Linear(ft_in, ft_out, bias=bias).cuda().to(model_dtype)
    m24 = LinearW24(ft_in, ft_out, bias=bias).cuda().to(model_dtype)
    with torch.autocast("cuda", dtype=dtype, enabled=amp):
        # Make weights sparse
        state_dict = m.state_dict()
        weight_sp24 = sp24.sparsify24(state_dict["weight"].abs())
        state_dict["weight"] = weight_sp24._sp24_to_dense().to(model_dtype).detach()
        m.load_state_dict(state_dict)
        m24.load_state_dict(state_dict)
        # FW with dense weights
        out = m(x)
        # FW with sparsity
        x24 = x.detach().requires_grad_()
        out24 = m24(x24)
    # Backward passes outside autocast
    out.backward(grad)
    out24.backward(grad)
    assert out24.is_contiguous()
    assert x24.grad is not None
    assert x24.grad.is_contiguous()
    assert m24.weight.grad is not None
    assert m24.weight.grad.is_contiguous()
    if bias:
        assert m24.bias.grad is not None
    assert_allclose(out24, out, msg="output", **atol_rtol_kw[dtype])
    assert x.grad is not None and x24.grad is not None
    assert_allclose(x24.grad, x.grad, msg="x.grad", **atol_rtol_kw[dtype])
    assert m.weight.grad is not None
    assert_allclose(
        m24.weight.grad.to(dtype),
        sp24.sparsify24_like(
            m.weight.grad.to(dtype), pattern=weight_sp24, out_dense=True
        ),
        msg="w.grad",
        **atol_rtol_kw[dtype],
    )
    if bias:
        assert m.bias.grad is not None
        assert m24.bias.grad is not None
        assert_allclose(
            m24.bias.grad.to(dtype),
            m.bias.grad.to(dtype),
            msg="bias.grad",
            **atol_rtol_kw[dtype],
        )


def test_wrong_alignment_error_message() -> None:
    A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
    B = torch.randn([128, 4], device="cuda", dtype=torch.float16)
    A = sp24.sparsify24(A, backend="cusparselt")
    with pytest.raises(NotImplementedError, match="aligned to 8"):
        A @ B


def test_min_alignment() -> None:
    A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
    B = torch.randn([128, 8], device="cuda", dtype=torch.float16)
    A = sp24.sparsify24(A, backend="cusparselt")
    assert_allclose(A @ B, A._sp24_to_dense() @ B, "output", **atol_rtol_kw[A.dtype])


def test_wrong_dtype_error_message() -> None:
    A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
    B = torch.randn([128, 16], device="cuda", dtype=torch.float32)
    A = sp24.sparsify24(A, backend="cusparselt")
    with pytest.raises(NotImplementedError, match="the same data type"):
        A @ B


def test_linear_dispatch_inference_mode(backend: str, with_bias: bool) -> None:
    B, ft_in, ft_out = 128, 256, 512
    x = torch.randn([B, ft_in], device="cuda", dtype=torch.float16)
    weight = torch.randn([ft_out, ft_in], device="cuda", dtype=torch.float16)
    bias = (
        torch.randn([ft_out], device="cuda", dtype=torch.float16) if with_bias else None
    )
    w_sparse = sp24.sparsify24(
        weight,
        gradient="24dense",
        backend=backend,
    )
    # NOTE: When in `inference_mode`, PyTorch no longer dispatches to `addmm`,
    # but to `linear`, so we need to support that as well
    with torch.inference_mode():
        # Does not support bias at the moment in CUTLASS backend
        if bias is not None and backend == sp24.BACKEND_CUTLASS:
            with pytest.raises(NotImplementedError):
                F.linear(x, w_sparse, bias)
            return
        out = F.linear(x, w_sparse, bias)
        out_ref = F.linear(x, w_sparse._sp24_to_dense(), bias)
        assert_allclose(out, out_ref, msg="output", **atol_rtol_kw[x.dtype])


def test_sp24_meta() -> None:
    x = torch.randn([1024, 512], device="meta", dtype=torch.float16)
    x_s = sp24.sparsify24(x, backend="cusparselt")
    assert x_s.shape == x.shape
    x_s_t = x_s.t()
    assert x_s_t.shape == x.t().shape


def test_sp24_compile(backend) -> None:
    x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
    e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

    def fn(x, e):
        y = sp24.sparsify24(x, backend=backend, gradient="24dense")
        y = y.t()
        return x @ y

    # Eager
    output = fn(x, e)
    output.backward(output)
    # Torch compile
    output = torch.compile(fn)(x, e)
    output.backward(output)


class _TransformerFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias: bool = True,
        linear_cls=nn.Linear,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = linear_cls(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = linear_cls(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


def test_linearw24_block_compile() -> None:
    # TODO: Parametrize on `dtype` when torch.compile gets faster
    # (currently takes ~5s per test)
    dtype = torch.bfloat16
    B, FT_IN, FT_HIDDEN = 31, 512, 2048
    _workaround_cusparselt_internal_error()
    m = _TransformerFFN(FT_IN, FT_HIDDEN, linear_cls=LinearW24).to("cuda").to(dtype)
    m_c = _TransformerFFN(FT_IN, FT_HIDDEN, linear_cls=LinearW24).to("cuda").to(dtype)
    m_c.load_state_dict(m.state_dict())
    m_c = cast(_TransformerFFN, torch.compile(m_c))
    x, grad = [torch.randn([B, FT_IN], dtype=dtype, device="cuda") for _ in range(2)]
    x = x.requires_grad_()
    out = m(x)
    out.backward(grad)
    x_c = x.detach().requires_grad_()
    out_c = m_c(x_c)
    out_c.backward(grad)
    assert_allclose(out_c, out, "output", **atol_rtol_kw[dtype])
    assert x_c.grad is not None and x.grad is not None
    assert_allclose(x_c.grad, x.grad, "x.grad", **atol_rtol_kw[dtype])
    for param_name, param_ref, param_c in [
        ["fc1.weight", m.fc1.weight, m_c.fc1.weight],
        ["fc1.bias", m.fc1.bias, m_c.fc1.bias],
        ["fc2.weight", m.fc2.weight, m_c.fc2.weight],
        ["fc2.bias", m.fc2.bias, m_c.fc2.bias],
    ]:
        assert param_ref.grad is not None and param_c.grad is not None, param_name
        assert_allclose(param_c.grad, param_ref.grad, param_name, **atol_rtol_kw[dtype])


def test_sp24_ste():
    x = torch.randn([512, 512], dtype=torch.float16, device="cuda", requires_grad=True)
    grad = torch.randn_like(x)
    spX = sp24.sparsify24(x, gradient=sp24.GRADIENT_STE)
    spX.backward(grad)
    assert_allclose(x.grad, grad, "grad")


def test_sparsify24_ste(dtype):
    x = torch.randn([512, 512], dtype=dtype, device="cuda", requires_grad=True)
    y = torch.randn([512, 512], dtype=dtype, device="cuda", requires_grad=True)
    mul0 = 2.0  # (numbers that have an exact representation in f16)
    mul1 = 0.5
    spX = sp24.sparsify24_ste(x, bw_mul0=mul0, bw_mul1=mul1)
    spX.backward(y)
    spYd = sp24.sparsify24_like(y, pattern=spX)._sp24_to_dense()
    ref = mul1 * spYd + mul0 * (y - spYd)
    assert_allclose(x.grad, ref, "grad")
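

# Note on the reference gradient above: `sparsify24_ste` acts as a
# straight-through estimator where entries kept by the 2:4 pattern of `x`
# receive the incoming gradient scaled by `bw_mul1`, while pruned entries
# receive it scaled by `bw_mul0` (with both multipliers equal to 1 this
# reduces to the plain pass-through checked in `test_sp24_ste`).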