# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import random
from typing import cast

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import xformers  # noqa: F401
import xformers.ops as xops
import xformers.ops.sp24 as sp24

from .utils import assert_allclose

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability("cuda")
torch_compile_tests = pytest.mark.skipif(
    torch.__version__ < "2.2.0.dev20231122", reason="requires PyTorch 2.2+"
)
requires_sp24 = pytest.mark.skipif(compute_capability < (8, 0), reason="requires sm80+")
requires_sp24_gemm = pytest.mark.skipif(
    compute_capability != (8, 0), reason="requires sm80"
)
parametrize_dtype = pytest.mark.parametrize(
    "dtype", [torch.float16, torch.bfloat16], ids=["f16", "bf16"]
)
parametrize_backend = pytest.mark.parametrize(
    "backend",
    [sp24.BACKEND_CUTLASS, sp24.BACKEND_CUSPARSELT]
    if sp24._has_cusparseLt()
    else [sp24.BACKEND_CUTLASS],
)

atol_rtol_kw = {
    torch.float16: {
        "rtol": 2e-3,
        "atol": 1e-4,
    },
    torch.bfloat16: {
        "rtol": 1e-1,
        "atol": 1e-1,
    },
}


@cuda_only
def test_sparse24_largest_mask_2d() -> None:
    inp = torch.tensor(
        [[4, 3, 2, 1], [0, 0, 0.5, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
        device="cuda",
        dtype=torch.float16,
    )
    out = torch.ops.xformers.sparse24_largest_mask_2d(inp)
    assert out.int().tolist() == [
        [1, 1, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 1, 1],
        [1, 0, 0, 1],
    ]


@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_autocast(dtype, backend: str) -> None:
    N = 128
    inp = torch.randn([N, N], dtype=torch.float32, device="cuda")
    W = torch.randn([N, N], dtype=torch.float32, device="cuda")
    sInp = sp24.sparsify24(inp.to(dtype=dtype), backend=backend)
    y = sInp @ W.to(dtype=dtype)

    with torch.autocast("cuda", dtype=dtype):
        sInp_ac = sp24.sparsify24(inp, backend=backend)
        y_ac = sInp_ac @ W
    assert_allclose(
        sInp._sp24_to_dense(),
        sInp_ac._sp24_to_dense(),
        "sparse24",
        **atol_rtol_kw[dtype],
    )
    assert_allclose(y, y_ac, "gemm", **atol_rtol_kw[dtype])


@requires_sp24_gemm
@parametrize_dtype
def test_sparse24_causal1122(dtype) -> None:
    inp = torch.tensor(
        [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
        device="cuda",
        dtype=dtype,
    )
    inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
    sInp = sp24.sparsify24(inp, algo="causal1122")

    mask = sInp._sp24_to_dense() / inp
    assert mask[:4, :4].int().tolist() == [
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 1],
        [1, 0, 0, 1],
    ]


@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_sparse24_largest_abs_values_greedy(dtype, backend) -> None:
    inp = torch.tensor(
        [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
        device="cuda",
        dtype=dtype,
    )
    inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
    sInp = sp24.sparsify24(inp, algo="largest_abs_values_greedy", backend=backend)

    mask = sInp._sp24_to_dense() / inp
    assert mask[:4, :4].int().tolist() == [
        [1, 1, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 1, 1],
        [1, 0, 0, 1],
    ]


@cuda_only
@parametrize_dtype
def test_sparse24_largest_mask_2d_notaligned(dtype) -> None:
    inp = torch.randn([5, 5], device="cuda", dtype=dtype)
    with pytest.raises(RuntimeError):
        torch.ops.xformers.sparse24_largest_mask_2d(inp)


@cuda_only
@parametrize_dtype
def test_sparse24_largest_mask_2d_big(dtype) -> None:
    inp = torch.randn([2048, 2048], device="cuda", dtype=dtype)
    torch.ops.xformers.sparse24_largest_mask_2d(inp)


def create_random_mask(shape) -> torch.Tensor:
    # Build a boolean mask with a valid 2:4 pattern: exactly two of every
    # four contiguous values in each row are kept
    r = random.Random(0)
    mask = torch.zeros(shape, dtype=torch.bool)
    for line in range(mask.shape[0]):
        for col in range(0, mask.shape[1], 4):
            sparsity = r.choice(
                [
                    [False, False, True, True],
                    [False, True, False, True],
                    [True, False, False, True],
                    [False, True, True, False],
                    [True, False, True, False],
                    [True, True, False, False],
                ]
            )
            mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
    return mask


@cuda_only
def test_detach_requires_grad() -> None:
    x = torch.randn([128, 64], device="cuda", dtype=torch.float16, requires_grad=True)
    xs = sp24.sparsify24(x)
    assert xs.requires_grad

    # `detach` behavior
    xs2 = xs.detach()
    assert not xs2.requires_grad
    assert not (xs2 * 2).requires_grad

    xs2.requires_grad_(True)
    assert xs2.requires_grad
    ys = xs2 * 2
    assert ys.requires_grad
    ys.backward(ys)


@cuda_only
def test_detach2() -> None:
    x = torch.randn([128, 64], device="cuda", dtype=torch.float16, requires_grad=False)
    assert not sp24.sparsify24(x).requires_grad
    x.requires_grad_(True)
    xs = sp24.sparsify24(x)
    assert xs.requires_grad
    xs2 = xs.detach()
    xs2.requires_grad_(True)
    xs3 = xs2 * 2
    assert xs3.requires_grad
    xs3.backward(xs3)

    assert xs2.grad is not None
    assert x.grad is None


@cuda_only
def test_meta_pack_and_reorder() -> None:
    mask = create_random_mask([32, 64])
    # Test a specific line with a known pattern
    line = 3
    mask[line, :16] = torch.tensor(
        [
            False, True, True, False,  # 1 << 0 | 2 << 2
            True, True, False, False,  # 0 << 4 | 1 << 6
            True, False, False, True,  # 0 << 8 | 3 << 10
            False, True, True, False,  # 1 << 12 | 2 << 14
        ],
        dtype=torch.bool,
    )
    packed = torch.ops.xformers._sparse24_pack_mask(mask)
    assert packed.shape == (mask.shape[0], mask.shape[1] // 16)
    # cast int16 -> uint16
    value_packed = (packed[line, 0].item() + (1 << 16)) % (1 << 16)
    expected_value = (
        1 << 0 | 2 << 2 | 0 << 4 | 1 << 6 | 0 << 8 | 3 << 10 | 1 << 12 | 2 << 14
    )
    assert value_packed == expected_value

    meta_reordered = torch.ops.xformers._sparse24_reorder_meta(packed)
    assert meta_reordered.ndim == 3
    assert meta_reordered.shape[0] == packed.shape[0]
    assert (meta_reordered[0, 0, 0] == packed[0, 0]).item()
    assert (meta_reordered[0, 1, 0] == packed[8, 0]).item()
    assert (meta_reordered[1, 0, 0] == packed[0, 1]).item()
    assert (meta_reordered[1, 1, 0] == packed[8, 1]).item()
    assert (meta_reordered[2, 0, 0] == packed[16, 0]).item()
    assert (meta_reordered[2, 1, 0] == packed[24, 0]).item()
    # second column
    assert (meta_reordered[0, 0, 1] == packed[0, 2]).item()


@cuda_only
def test_pack_tensor_according_to_mask() -> None:
    mask = create_random_mask([32, 64])
    # Test a specific line with a known pattern
    line = 3
    line_pattern = [
        False, True, True, False,
        True, True, False, False,
        True, False, False, True,
        False, True, True, False,
    ]
    mask[line, :16] = torch.tensor(line_pattern, dtype=torch.bool)
    packed = torch.ops.xformers._sparse24_pack_mask(mask)
    reordered = torch.ops.xformers._sparse24_reorder_meta(packed)
    a_full = torch.randn(mask.shape, dtype=torch.float16)
    a_packed = torch.ops.xformers._sparse24_pack_tensor_according_to_mask(
        a_full, reordered
    )
    line_full = a_full[line, :16].tolist()
    line_packed = a_packed[line, :8].tolist()
    line_packed_expected = [
        value for index, value in enumerate(line_full) if line_pattern[index]
    ]
    assert line_packed == line_packed_expected


@requires_sp24_gemm
@parametrize_dtype
def test_sp24_gemm(dtype) -> None:
    M, N, K = 32, 32, 64
    a = torch.randn([M, K], device="cuda", dtype=dtype)
device="cuda", dtype=dtype) b = torch.randn([K, N], device="cuda", dtype=dtype) mask = create_random_mask([M, K]) mask_packed = torch.ops.xformers._sparse24_pack_mask(mask) mask_reordered = torch.ops.xformers._sparse24_reorder_meta(mask_packed) packed_a = torch.ops.xformers._sparse24_pack_tensor_according_to_mask( a.cpu(), mask_reordered ) packed_a = packed_a.cuda() mask_reordered = mask_reordered.cuda() mask = mask.to(dtype).cuda() masked_a = a * mask ref_out = masked_a @ b sp24_out = torch.ops.xformers._sparse24_gemm(packed_a, b, mask_reordered) assert_allclose(ref_out, sp24_out, msg="sp24 GEMM", **atol_rtol_kw[dtype]) @cuda_only @pytest.mark.parametrize("transpose", [True, False]) def test_pack_meta_shuffle(transpose: bool) -> None: local_meta = torch.zeros([4, 8, 8], dtype=torch.int64, device="cuda") local_meta[:2, :2] = torch.randint( 0, 256, size=(2, 2, 8), dtype=torch.int64, device="cuda" ) final_meta_tensor = torch.ops.xformers._sparse24_meta_shuffle_test( local_meta, transpose ) assert final_meta_tensor[2:, 2:].abs().max().item() == 0 final_meta = final_meta_tensor.tolist() def pack(line): if transpose: return int( local_meta[0, 0, line] | (local_meta[1, 0, line] << 8) | (local_meta[0, 1, line] << 16) | (local_meta[1, 1, line] << 24) ) else: return int( local_meta[0, 0, line] | (local_meta[0, 1, line] << 8) | (local_meta[1, 0, line] << 16) | (local_meta[1, 1, line] << 24) ) def meta_str(m): return " ".join(f"0x{mm:02x}" for mm in m.tolist()) def expect_match(i, j, line): value = final_meta[i][j][0] expected = pack(line) assert ( value == expected ), f"""value: 0x{value:02x} (expected: 0x{expected:02x}) {meta_str(local_meta[0, 0, :4])} (T0) |||| {meta_str(local_meta[0, 1, :4])} (T4) {meta_str(local_meta[1, 0, :4])} (T1) |||| {meta_str(local_meta[1, 1, :4])} (T5) """ expect_match(0, 0, 0) # T0 if transpose: expect_match(1, 0, 1) # T1 expect_match(0, 1, 2) # T4 else: expect_match(0, 1, 1) # T4 expect_match(1, 0, 2) # T1 expect_match(1, 1, 3) # T5 @requires_sp24_gemm @parametrize_dtype @parametrize_backend def test_pack_both_ways_meta_correctness(dtype, backend) -> None: M, N = 128, 256 # Construct x to make sure we always have exactly 8 elements per 4x4 tile a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :] a = a.repeat(M // 8, N // 8) assert a.shape == (M, N) a = a.cuda().to(dtype) b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype) a_sparse = sp24.sparsify24(a, backend=backend) mask_dense = torch.ops.xformers.sparse24_largest_mask_2d(a) if backend == sp24.BACKEND_CUTLASS: assert isinstance(a_sparse, sp24.Sparse24TensorCutlass) mask_packed = torch.ops.xformers._sparse24_pack_mask(mask_dense.cpu().bool()) mask_reordered = torch.ops.xformers._sparse24_reorder_meta(mask_packed).cuda() assert torch.allclose(a_sparse.meta.view(torch.short), mask_reordered) ref_gemm = (mask_dense * a) @ b pack_gemm = a_sparse @ b assert_allclose(ref_gemm, pack_gemm, msg="sp24 GEMM", **atol_rtol_kw[dtype]) @requires_sp24_gemm @parametrize_dtype def test_pack_both_ways_id(dtype) -> None: N = 512 torch.manual_seed(0) a = torch.randn([N, N], dtype=dtype, device="cuda") b = torch.eye(N, dtype=dtype, device="cuda") packed, meta, packed_t, meta_t = torch.ops.xformers.sparse24_sparsify_both_ways(a)[ :4 ] # Heuristic to ensure we pack the same values assert torch.allclose( packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum() ) mask_dense = torch.ops.xformers.sparse24_largest_mask_2d(a.to(dtype)) ref_gemm = mask_dense * a # Test A@B pack_gemm = 
    max_diff = (ref_gemm - pack_gemm).abs().argmax()
    assert torch.allclose(
        ref_gemm, pack_gemm
    ), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"

    # Test A.t@B
    pack_gemm = torch.ops.xformers._sparse24_gemm(packed_t, b, meta_t)
    pack_gemm = pack_gemm.transpose(0, 1)
    max_diff = (ref_gemm - pack_gemm).abs().argmax()
    assert torch.allclose(
        ref_gemm, pack_gemm
    ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"


@cuda_only
@parametrize_dtype
def test_pack_both_ways_edge_case1(dtype) -> None:
    # In this case, the heuristic will keep 7 values out of 16
    # instead of 8. Let's see how the kernel handles this
    quad = torch.tensor(
        [
            [2, -1, -2, -3],  # Should be packed as `2 <null>`
            [-1, 8, -1, 6],
            [-1, -1, 4, 5],
            [-1, 3, 7, -1],
        ],
        dtype=dtype,
        device="cuda",
    )
    a = torch.randn([32, 64], dtype=dtype, device="cuda")
    a[:4, :4] = quad
    packed, meta, packed_t, meta_t = torch.ops.xformers.sparse24_sparsify_both_ways(a)[
        :4
    ]
    # Check first line in A
    assert packed[0, 0].item() == 2
    assert packed[0, 1].item() == 0
    # And first column in A.t
    assert packed_t[0, 0].item() == 2
    assert packed_t[0, 1].item() == 0


@cuda_only
@parametrize_dtype
def test_sp24_apply(dtype) -> None:
    M, N = 256, 1024
    x = torch.randn([M, N], dtype=dtype, device="cuda")
    (
        packed,
        meta,
        packed_t,
        meta_t,
        threads_masks,
    ) = torch.ops.xformers.sparse24_sparsify_both_ways(x)
    packed2, _, packed_t2, _ = torch.ops.xformers.sparse24_apply(x, threads_masks)
    assert torch.allclose(packed, packed2)
    assert torch.allclose(packed_t, packed_t2)


@cuda_only
@parametrize_dtype
def test_sp24_api_different_pattern(dtype) -> None:
    M, N = 256, 256
    x = torch.randn([M, N], dtype=dtype, device="cuda")
    y = torch.randn([M, N], dtype=dtype, device="cuda")
    sx = sp24.sparsify24(x)
    sy = sp24.sparsify24(y)

    # Can't add with different sparsity pattern
    with pytest.raises(ValueError):
        sx + sy

    # Ok, same sparsity pattern
    assert isinstance(sx + sx, sp24.Sparse24Tensor)

    # Ok, sharing sparsity pattern of x
    sy2 = sp24.sparsify24_like(y, sx)
    assert isinstance(sx + sy2, sp24.Sparse24Tensor)


@cuda_only
@parametrize_dtype
def test_sp24_api_different_pattern_transposed(dtype) -> None:
    N = 256
    x = torch.randn([N, N], dtype=dtype, device="cuda")
    sx = sp24.sparsify24(x, backend=sp24.BACKEND_CUTLASS)
    sxt = sx.t()
    assert isinstance(sxt, sp24.Sparse24Tensor)

    # Can't add with different sparsity pattern
    with pytest.raises(ValueError):
        sx + sxt
    # But this should work
    sx + sxt.t()

    # And we should be able to sparsify with transposed pattern
    sxt2 = sp24.sparsify24_like(x.t(), sxt)
    assert torch.allclose(sxt2.packed, sxt.packed)
    assert torch.allclose(sxt2.packed_t, sxt.packed_t)


@requires_sp24_gemm
@parametrize_dtype
@parametrize_backend
def test_sp24_transpose_invariant(dtype, backend) -> None:
    M, N = 128, 256
    torch.manual_seed(0)
    r = random.Random(0)

    def gen4x4():
        # Create a 4x4 tile that can be 2:4 sparsified perfectly
        values = [
            [1, 1, 0, 0],
            [0, 1, 1, 0],
            [0, 0, 1, 1],
            [1, 0, 0, 1],
        ]
        c1, c2 = r.sample([0, 1, 2, 3], 2)
        r1, r2 = r.sample([0, 1, 2, 3], 2)
        values[r1], values[r2] = values[r2], values[r1]
        for i in range(4):
            values[i][c1], values[i][c2] = values[i][c2], values[i][c1]
        return values

    a = torch.zeros([M, N], device="cuda", dtype=torch.float16)
    assert M % 4 == 0 and N % 4 == 0
    for m in range(0, M, 4):
        for n in range(0, N, 4):
            a[m : m + 4, n : n + 4] = torch.tensor(
                gen4x4(), device="cuda", dtype=torch.float16
            )
    a = a * torch.randn_like(a).abs()

    # Sparsify `a` and `a.t()`
    a_s = sp24.sparsify24(a, backend=backend)
    a_t_s = sp24.sparsify24(a.t().contiguous(), backend=backend)
    assert_allclose(a_s._sp24_to_dense(), a)
    assert_allclose(a_t_s.t()._sp24_to_dense(), a)  # type: ignore
    assert_allclose(a_t_s._sp24_to_dense().t(), a)


@requires_sp24_gemm
@parametrize_dtype
def test_sp24_matmuls(dtype) -> None:
    M, N, K = 64, 256, 1024
    a = torch.randn([M, K], device="cuda", dtype=dtype)
    b = torch.randn([K, N], device="cuda", dtype=dtype)
    a_m = torch.ops.xformers.sparse24_largest_mask_2d(a)
    b_m = torch.ops.xformers.sparse24_largest_mask_2d(b)
    a_s = sp24.sparsify24(a)
    b_s = sp24.sparsify24(b)

    assert_allclose(a_s @ b, (a * a_m) @ b, msg="sp@dense", **atol_rtol_kw[dtype])
    assert_allclose(a @ b_s, a @ (b * b_m), msg="dense@sp", **atol_rtol_kw[dtype])
    assert_allclose(
        a @ a_s.t(), a @ (a * a_m).t(), msg="dense@sp.t", **atol_rtol_kw[dtype]
    )
    assert_allclose(
        a_s.t() @ a, (a * a_m).t() @ a, msg="sp.t@dense", **atol_rtol_kw[dtype]
    )


@requires_sp24
def test_sp24_matmuls_mat_vec() -> None:
    a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
    b = torch.randn([128], device="cuda", dtype=torch.float16)
    a_m = torch.ops.xformers.sparse24_largest_mask_2d(a)
    a_s = sp24.sparsify24(a)

    with pytest.raises(NotImplementedError):
        assert_allclose(a_s @ b, (a * a_m) @ b, msg="sp@dense", **atol_rtol_kw[a.dtype])


@requires_sp24
def test_sp24_matmuls_bmm() -> None:
    a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
    b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
    a_m = torch.ops.xformers.sparse24_largest_mask_2d(a)
    a_s = sp24.sparsify24(a)

    with pytest.raises(NotImplementedError):
        assert_allclose(a_s @ b, (a * a_m) @ b, msg="sp@dense", **atol_rtol_kw[a.dtype])


def sparsify24_dense(tensor: torch.Tensor):
    # Dense reference: apply the 2:4 mask but return a regular dense tensor
    m = torch.ops.xformers.sparse24_largest_mask_2d(tensor)
    return m * tensor


@requires_sp24_gemm
@parametrize_dtype
@pytest.mark.parametrize("act", [F.gelu, F.relu])
def test_sp24_api_mlp_act24_correctness(dtype, act) -> None:
    B, in_ft, hid_ft, out_ft = 256, 2048, 6144, 2048
    torch.manual_seed(0)
    x = torch.randn([B, in_ft], dtype=dtype, device="cuda", requires_grad=True)
    w1 = (
        torch.randn([in_ft, hid_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    w2 = (
        torch.randn([hid_ft, out_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    grad = (
        torch.randn([B, out_ft], dtype=dtype, device="cuda", requires_grad=False) * 0.1
    )
    w1.requires_grad_(True)
    w2.requires_grad_(True)

    params_with_grads = [x, w1, w2]

    # Run baseline
    x1 = x @ w1
    x1 = sparsify24_dense(x1)
    x1 = act(x1)
    out = x1 @ w2
    out.backward(grad)

    grads_ref = [t.grad for t in params_with_grads]
    for t in params_with_grads:
        t.grad = None

    # Run with sparsity
    x1 = x @ w1
    x1 = sp24.sparsify24(x1)
    x1 = act(x1)
    out = x1 @ w2
    out.backward(grad)

    for grad_name, grad_ref, grad_calc in zip(
        ["x", "w1", "w2"], grads_ref, [t.grad for t in params_with_grads]
    ):
        assert grad_calc is not None, grad_name
        assert grad_ref is not None, grad_name
        assert_allclose(grad_calc, grad_ref, msg=grad_name, **atol_rtol_kw[dtype])


@requires_sp24_gemm
@parametrize_dtype
def test_sp24_api_swiglu_correctness(dtype) -> None:
    B, in_ft, hid_ft, out_ft = 256, 2048, 6144 // 2, 2048
    torch.manual_seed(0)
    x = torch.randn([B, in_ft], dtype=dtype, device="cuda", requires_grad=True)
    w1 = (
        torch.randn([in_ft, hid_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    w2 = (
        torch.randn([in_ft, hid_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    w3 = (
        torch.randn([hid_ft, out_ft], dtype=dtype, device="cuda", requires_grad=False)
        * 0.01
    )
    grad = (
        torch.randn([B, out_ft], dtype=dtype, device="cuda", requires_grad=False) * 0.1
    )
dtype=dtype, device="cuda", requires_grad=False) * 0.1 ) w1.requires_grad_(True) w2.requires_grad_(True) w3.requires_grad_(True) params_with_grads = [x, w1, w2, w3] # Run baseline x1 = x @ w1 x2 = x @ w2 x1s = sparsify24_dense(F.silu(x1)) hid = x1s * x2 out = hid @ w3 out.backward(grad) grads_ref = [t.grad for t in params_with_grads] for t in params_with_grads: t.grad = None # Run with sparsity x1 = x @ w1 x2 = x @ w2 x1s = sp24.sparsify24(F.silu(x1)) hid = x1s * x2 out = hid @ w3 out.backward(grad) for grad_name, grad_ref, grad_calc in zip( ["x", "w1", "w2", "w3"], grads_ref, [t.grad for t in params_with_grads] ): assert grad_calc is not None, grad_name assert grad_ref is not None, grad_name assert_allclose(grad_calc, grad_ref, msg=grad_name, **atol_rtol_kw[dtype]) @requires_sp24_gemm @parametrize_dtype @pytest.mark.parametrize("M", [1, 8, 26, 31, 32, 48, 63]) def test_not_aligned(dtype, M): N, K = 64, 128 A = torch.randn([M, K], dtype=dtype, device="cuda") B = torch.randn([K, N], dtype=dtype, device="cuda") As = sp24.sparsify24(A) A = As._sp24_to_dense() assert tuple(A.shape) == (M, K), A.shape assert_allclose(As @ B, A @ B, msg="not aligned", **atol_rtol_kw[dtype]) @requires_sp24_gemm @parametrize_dtype @parametrize_backend @pytest.mark.parametrize("input_rowmajor", [True, False]) def test_sparsify24_like_dense(dtype, input_rowmajor, backend): M, N = 128, 256 if input_rowmajor: x = torch.randn([M, N], dtype=dtype, device="cuda") else: x = torch.randn([N, M], dtype=dtype, device="cuda").t() sx = sp24.sparsify24(x.contiguous(), backend=backend) sx_like = sp24.sparsify24_like(x, pattern=sx, backend="dense") assert_allclose( sx_like, sx._sp24_to_dense(), msg="sp24_like", **atol_rtol_kw[dtype] ) @requires_sp24_gemm @parametrize_dtype @parametrize_backend def test_sparsify24_weights(dtype, backend): x = torch.randn([128, 512], dtype=dtype, device="cuda", requires_grad=True) w = torch.randn([1024, 512], dtype=dtype, device="cuda", requires_grad=True) flat_w = w.flatten() # FSDP-like processing w = flat_w.reshape(w.shape) sw = sp24.sparsify24(w, gradient="24dense", backend=backend) y = x @ sw.t() y.backward(y) class LinearW24(torch.nn.Linear): def forward(self, input: torch.Tensor) -> torch.Tensor: input_shape = input.shape input = input.flatten(end_dim=-2) dim0 = input.shape[0] if dim0 % 8 != 0: # NOTE: This should be torch-compiled away input = F.pad(input, [0, 0, 0, -dim0 % 8]) w_sparse = xops.sparsify24( self.weight, gradient="24dense", backend="cusparselt", ) return F.linear(input, w_sparse, self.bias,)[ :dim0 ].unflatten(dim=0, sizes=input_shape[:-1]) # XXX: This is needed to avoid a CUDA internal error # See the issue here: # https://github.com/pytorch/pytorch/issues/113776 @functools.lru_cache() def _workaround_cusparselt_internal_error() -> None: x0 = torch.randn([128, 128], device="cuda", dtype=torch.float16, requires_grad=True) m = LinearW24(128, 128, bias=False).cuda().to(torch.float16) out = m(x0) out.backward(out) @requires_sp24 @parametrize_dtype @pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt") @pytest.mark.parametrize("bias", [False, True], ids=["", "bias"]) @pytest.mark.parametrize("aligned", [False, True], ids=["misaligned", ""]) @pytest.mark.parametrize("amp", [False, True], ids=["", "amp"]) def test_linearw24(dtype, bias: bool, aligned: bool, amp: bool) -> None: _workaround_cusparselt_internal_error() B, ft_in, ft_out = 64, 128, 256 if not aligned: B = 65 model_dtype = torch.float32 if amp else dtype x = torch.randn([B, ft_in], device="cuda", 
    grad = torch.randn([B, ft_out], device="cuda", dtype=model_dtype)
    m = torch.nn.Linear(ft_in, ft_out, bias=bias).cuda().to(model_dtype)
    m24 = LinearW24(ft_in, ft_out, bias=bias).cuda().to(model_dtype)

    with torch.autocast("cuda", dtype=dtype, enabled=amp):
        # Make weights sparse
        state_dict = m.state_dict()
        weight_sp24 = sp24.sparsify24(state_dict["weight"].abs())
        state_dict["weight"] = weight_sp24._sp24_to_dense().to(model_dtype).detach()
        m.load_state_dict(state_dict)
        m24.load_state_dict(state_dict)

        # FW with dense weights
        out = m(x)
        # FW with sparsity
        x24 = x.detach().requires_grad_()
        out24 = m24(x24)

    # Backward passes outside autocast
    out.backward(grad)
    out24.backward(grad)

    assert out24.is_contiguous()
    assert x24.grad is not None
    assert x24.grad.is_contiguous()
    assert m24.weight.grad is not None
    assert m24.weight.grad.is_contiguous()
    if bias:
        assert m24.bias.grad is not None

    assert_allclose(out24, out, msg="output", **atol_rtol_kw[dtype])
    assert x.grad is not None and x24.grad is not None
    assert_allclose(x24.grad, x.grad, msg="x.grad", **atol_rtol_kw[dtype])
    assert m.weight.grad is not None
    assert_allclose(
        m24.weight.grad.to(dtype),
        sp24.sparsify24_like(
            m.weight.grad.to(dtype), pattern=weight_sp24, out_dense=True
        ),
        msg="w.grad",
        **atol_rtol_kw[dtype],
    )
    if bias:
        assert m.bias.grad is not None
        assert m24.bias.grad is not None
        assert_allclose(
            m24.bias.grad.to(dtype),
            m.bias.grad.to(dtype),
            msg="bias.grad",
            **atol_rtol_kw[dtype],
        )


@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_wrong_alignment_error_message() -> None:
    A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
    B = torch.randn([128, 4], device="cuda", dtype=torch.float16)
    A = sp24.sparsify24(A, backend="cusparselt")
    with pytest.raises(NotImplementedError, match="aligned to 8"):
        A @ B


@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_min_alignment() -> None:
    A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
    B = torch.randn([128, 8], device="cuda", dtype=torch.float16)
    A = sp24.sparsify24(A, backend="cusparselt")
    assert_allclose(A @ B, A._sp24_to_dense() @ B, "output", **atol_rtol_kw[A.dtype])


@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_wrong_dtype_error_message() -> None:
    A = torch.randn([128, 128], device="cuda", dtype=torch.float16)
    B = torch.randn([128, 16], device="cuda", dtype=torch.float32)
    A = sp24.sparsify24(A, backend="cusparselt")
    with pytest.raises(NotImplementedError, match="the same data type"):
        A @ B


@requires_sp24_gemm
@parametrize_backend
@pytest.mark.parametrize("with_bias", [False, True])
def test_linear_dispatch_inference_mode(backend: str, with_bias: bool) -> None:
    B, ft_in, ft_out = 128, 256, 512
    x = torch.randn([B, ft_in], device="cuda", dtype=torch.float16)
    weight = torch.randn([ft_out, ft_in], device="cuda", dtype=torch.float16)
    bias = (
        torch.randn([ft_out], device="cuda", dtype=torch.float16) if with_bias else None
    )
    w_sparse = sp24.sparsify24(
        weight,
        gradient="24dense",
        backend=backend,
    )

    # NOTE: When in `inference_mode`, PyTorch no longer dispatches to `addmm`,
    # but to `linear`, so we need to support that as well
    with torch.inference_mode():
        # The CUTLASS backend does not support bias at the moment
        if bias is not None and backend == sp24.BACKEND_CUTLASS:
            with pytest.raises(NotImplementedError):
                F.linear(x, w_sparse, bias)
            return
        out = F.linear(x, w_sparse, bias)

    out_ref = F.linear(x, w_sparse._sp24_to_dense(), bias)
    assert_allclose(out, out_ref, msg="output", **atol_rtol_kw[x.dtype])


@torch_compile_tests
@cuda_only
def test_sp24_meta() -> None:
    x = torch.randn([1024, 512], device="meta", dtype=torch.float16)
    x_s = sp24.sparsify24(x, backend="cusparselt")
    assert x_s.shape == x.shape
    x_s_t = x_s.t()
    assert x_s_t.shape == x.t().shape


@torch_compile_tests
@requires_sp24_gemm
@parametrize_backend
def test_sp24_compile(backend) -> None:
    x = torch.randn(
        [1024, 512], device="cuda", dtype=torch.float16, requires_grad=True
    )
    e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)

    def fn(x, e):
        y = sp24.sparsify24(x, backend=backend, gradient="24dense")
        y = y.t()
        return x @ y

    # Eager
    output = fn(x, e)
    output.backward(output)
    # Torch compile
    output = torch.compile(fn)(x, e)
    output.backward(output)


class _TransformerFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        bias: bool = True,
        linear_cls=nn.Linear,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = linear_cls(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = linear_cls(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


@requires_sp24_gemm
@torch_compile_tests
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_linearw24_block_compile() -> None:
    # TODO: Parametrize on `dtype` when torch.compile gets faster
    # currently takes ~5s per test
    dtype = torch.bfloat16
    B, FT_IN, FT_HIDDEN = 31, 512, 2048

    _workaround_cusparselt_internal_error()
    m = _TransformerFFN(FT_IN, FT_HIDDEN, linear_cls=LinearW24).to("cuda").to(dtype)
    m_c = _TransformerFFN(FT_IN, FT_HIDDEN, linear_cls=LinearW24).to("cuda").to(dtype)
    m_c.load_state_dict(m.state_dict())
    m_c = cast(_TransformerFFN, torch.compile(m_c))

    x, grad = [torch.randn([B, FT_IN], dtype=dtype, device="cuda") for _ in range(2)]
    x = x.requires_grad_()
    out = m(x)
    out.backward(grad)

    x_c = x.detach().requires_grad_()
    out_c = m_c(x_c)
    out_c.backward(grad)

    assert_allclose(out_c, out, "output", **atol_rtol_kw[dtype])
    assert x_c.grad is not None and x.grad is not None
    assert_allclose(x_c.grad, x.grad, "output", **atol_rtol_kw[dtype])
    for param_name, param_ref, param_c in [
        ["fc1.weight", m.fc1.weight, m_c.fc1.weight],
        ["fc1.bias", m.fc1.bias, m_c.fc1.bias],
        ["fc2.weight", m.fc2.weight, m_c.fc2.weight],
        ["fc2.bias", m.fc2.bias, m_c.fc2.bias],
    ]:
        assert param_ref.grad is not None and param_c.grad is not None, param_name
        assert_allclose(param_c.grad, param_ref.grad, param_name, **atol_rtol_kw[dtype])


@requires_sp24
@pytest.mark.skipif(not sp24._has_cusparseLt(), reason="requires cusparselt")
def test_sp24_ste():
    x = torch.randn([512, 512], dtype=torch.float16, device="cuda", requires_grad=True)
    grad = torch.randn_like(x)
    spX = sp24.sparsify24(x, gradient=sp24.GRADIENT_STE)
    spX.backward(grad)
    assert_allclose(x.grad, grad, "grad")


@requires_sp24_gemm
@parametrize_dtype
def test_sparsify24_ste(dtype):
    x = torch.randn([512, 512], dtype=dtype, device="cuda", requires_grad=True)
    y = torch.randn([512, 512], dtype=dtype, device="cuda", requires_grad=True)
    mul0 = 2.0  # (numbers that have an exact representation in f16)
    mul1 = 0.5
    spX = sp24.sparsify24_ste(x, bw_mul0=mul0, bw_mul1=mul1)
    spX.backward(y)
    spYd = sp24.sparsify24_like(y, pattern=spX)._sp24_to_dense()
    ref = mul1 * spYd + mul0 * (y - spYd)
    assert_allclose(x.grad, ref, "grad")