# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import copy

import pytest
import torch

# needed to register custom ops
import xformers  # noqa: F401
from xformers.ops import masked_matmul
from xformers.sparse import BlockSparseTensor, SparseCSRTensor

from .utils import disable_tf32

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
_devices = (
    ["cpu", "cuda:0"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"]
)
_tensor_types = [BlockSparseTensor, SparseCSRTensor]


def _create_blocksparse_tensor(
    device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
):
    """Build a random ``BlockSparseTensor`` for the tests.

    A random 0/1 layout of shape ``(C, H // block_size, W // block_size)`` is
    drawn, then its first block-row and first block-column are forced to 1.
    The values are ``randn`` samples of shape
    ``(Z, layout.sum(), block_size, block_size)`` cast to ``dtype``.
    """
    layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
    # NOTE(review): presumably guarantees that no block-row / block-column is
    # completely empty -- confirm against the blocksparse kernels' requirements.
    layout[:, :, 0] = 1
    layout[:, 0, :] = 1
    values = torch.randn(Z, layout.sum(), block_size, block_size, device=device).to(
        dtype
    )

    return BlockSparseTensor(values, layout)


def _create_csr_tensor(device, dtype, shape, sparsity, divisible_by=4):
    """Build a random ``SparseCSRTensor`` from a 3-dimensional dense tensor.

    A single keep-mask is drawn from the first batch entry (elements are kept
    where a uniform sample exceeds ``sparsity``) and the same ``(i, j)``
    positions are kept for every batch entry. The number of kept positions is
    truncated down to a multiple of ``divisible_by``.
    """
    matrix = torch.rand(shape, dtype=torch.float32, device=device).to(dtype)
    assert matrix.ndim == 3
    keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
    nonzero = torch.nonzero(keep)
    nnz = nonzero.shape[0]
    # NOTE: need to make it a multiple of 4 for sputnik
    nonzero = nonzero[: (nnz - nnz % divisible_by)]
    i, j = nonzero.unbind(1)
    output = torch.zeros_like(matrix)
    # Column vector of batch indices so that the (i, j) pairs broadcast over
    # every batch entry.
    bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None]
    output[bdim, i, j] = matrix[bdim, i, j]
    return SparseCSRTensor.from_dense(output)


def _create_tensor(tensor_type, device, dtype, shape, sparsity):
    """Create a random sparse tensor of the requested ``tensor_type``.

    For ``BlockSparseTensor`` the ``shape`` and ``sparsity`` arguments are not
    used: the tensor is built with a block size of 16 and the default
    dimensions of ``_create_blocksparse_tensor``.

    Raises:
        ValueError: if ``tensor_type`` is neither ``BlockSparseTensor`` nor
            ``SparseCSRTensor``.
    """
    if tensor_type is BlockSparseTensor:
        block_size = 16
        return _create_blocksparse_tensor(
            device=device, dtype=dtype, block_size=block_size
        )
    if tensor_type is SparseCSRTensor:
        return _create_csr_tensor(
            device=device, dtype=dtype, shape=shape, sparsity=sparsity
        )
    # Fail loudly instead of silently returning None, which would only surface
    # later as an unrelated AttributeError inside the calling test.
    raise ValueError(f"Unsupported tensor type: {tensor_type}")


def _seed():
    """Seed the torch CPU and CUDA random number generators with 42."""
    torch.random.manual_seed(42)
    torch.cuda.manual_seed_all(42)


def _get_dtype_atol(tensor_type, device: str):
    """Seed the RNGs and return the ``(dtype, atol)`` pair to use in a test.

    Side effect: sets the global ``allow_tf32`` flags of
    ``torch.backends.cuda.matmul`` and ``torch.backends.cudnn``. They are
    enabled (with a loose tolerance of 1e-1) for ``BlockSparseTensor`` on a
    CUDA device, and disabled (tolerance 1e-5) otherwise.
    """
    _seed()

    if tensor_type is BlockSparseTensor and "cuda" in device:
        # Upstream GPU blocksparse (Triton op) uses TF32 by default for all internal computations
        # TF32 has the precision of fp16 but the range of fp32
        # See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True  # type: ignore
        return torch.float32, 1e-1

    # Force pytorch to keep its computations as float32 (will default to tf32 with recent cuda and ampere+ GPU)
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False  # type: ignore

    return torch.float32, 1e-5


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("func", [torch.add, torch.mul])
def test_sparse_binary_ops(func, device):
    """Check that ``func`` on two CSR tensors matches the dense computation."""
    # TODO: add for BlockSparseTensor as well
    N, H, W = 8, 64, 64
    sparsity = 0.5
    shape = (N, H, W)

    a_sparse = _create_tensor(
        SparseCSRTensor, device, dtype=torch.float32, shape=shape, sparsity=sparsity
    )
    a = a_sparse.to_dense()

    # Both operands are the same tensor, in dense and in sparse form.
    b = a
    b_sparse = a_sparse

    res = func(a_sparse, b_sparse).to_dense()
    res_gt = func(a, b)

    assert torch.allclose(res, res_gt)


@disable_tf32
@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_masked_matmul(tensor_type, device):
    """Compare ``masked_matmul`` with a sparse mask against a dense mask.

    Both the forward result and the gradients w.r.t. the two operands are
    compared.
    """
    N, C, H, W, L = 8, 2, 64, 64, 32
    sparsity = 0.7
    dtype, atol = _get_dtype_atol(tensor_type, device)

    shape0 = (N, C, H, W)
    shape1 = (N, C, H, L)
    shape2 = (N, C, W, L)

    if tensor_type is not BlockSparseTensor:
        # The CSR helper asserts a 3-dimensional input: drop the leading dim.
        shape0 = shape0[1:]
        shape1 = shape1[1:]
        shape2 = shape2[1:]

    mask_sparse = _create_tensor(
        tensor_type, device, dtype=torch.bool, shape=shape0, sparsity=sparsity
    )
    mask = mask_sparse.to_dense()

    a = torch.randn(shape1, device=device, dtype=dtype)
    b = torch.randn(shape2, device=device, dtype=dtype)

    # Independent copies so that each code path accumulates its own gradients.
    aa = a.clone()
    bb = b.clone()

    a.requires_grad_(True)
    b.requires_grad_(True)
    aa.requires_grad_(True)
    bb.requires_grad_(True)

    bt = b.transpose(-2, -1)
    bbt = bb.transpose(-2, -1)

    res_gt = masked_matmul(a, bt, mask)
    res = masked_matmul(aa, bbt, mask_sparse)

    res_dense = res.to_dense()
    # Fill masked-out positions with -inf before comparing.
    # NOTE(review): presumably the dense-mask reference holds -inf there -- confirm.
    res_dense = torch.where(mask, res_dense, torch.full_like(res_dense, float("-inf")))

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res_dense, res_gt, atol=atol)

    # try to workaround non-contiguous issues with triton for now
    res_gt.backward(torch.ones_like(res_gt))
    res.values().backward(torch.ones_like(res.values()))

    assert torch.allclose(a.grad, aa.grad, atol=atol)
    assert torch.allclose(b.grad, bb.grad, atol=atol)


@disable_tf32
@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_bmm(tensor_type, device):
    """Compare ``sparse @ dense`` against the dense product, forward and backward."""
    N, C, H, W, L = 8, 2, 64, 64, 32
    dtype, atol = _get_dtype_atol(tensor_type, device)
    sparsity = 0.8

    shape0 = (N, C, H, W)
    shape1 = (N, C, W, L)

    if tensor_type is not BlockSparseTensor:
        # The CSR helper asserts a 3-dimensional input: drop the leading dim.
        shape0 = shape0[1:]
        shape1 = shape1[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )
    a = a_sparse.to_dense()
    # Non-zero positions of the dense view; used below to restrict the dense
    # gradient before comparing it with the sparse one.
    mask = a != 0

    a_sparse.requires_grad_(True)
    a.requires_grad_(True)

    b = torch.randn(shape1, device=device, dtype=dtype)
    b2 = b.clone()

    b.requires_grad_(True)
    b2.requires_grad_(True)

    res_gt = a @ b
    res = a_sparse @ b2

    assert res.dtype == res_gt.dtype
    assert torch.allclose(
        res, res_gt, atol=atol
    ), f"{torch.max(torch.abs(res-res_gt))} - tolerance: {atol}"

    res_gt.sum().backward()
    res.sum().backward()

    # Zero the dense gradient outside the mask so that it is comparable with
    # the densified sparse gradient.
    a_grad = a.grad.clone().detach()
    a_grad[~mask] = 0

    assert torch.allclose(b.grad, b2.grad, atol=atol)
    assert torch.allclose(
        a_grad, a_sparse.grad.to_dense(), atol=atol
    ), f"{torch.max(torch.abs(a_grad-a_sparse.grad.to_dense()))}"


@disable_tf32
@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax(tensor_type, device):
    """Compare softmax on a sparse tensor with softmax on a -inf filled dense one."""
    N, C, H, W = 8, 2, 64, 64
    dtype, atol = _get_dtype_atol(tensor_type, device)
    sparsity = 0.8

    shape0 = (N, C, H, W)
    if tensor_type is not BlockSparseTensor:
        # The CSR helper asserts a 3-dimensional input: drop the leading dim.
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )
    a = a_sparse.to_dense()
    mask = a != 0

    # Zero entries of the dense reference are replaced with -inf so that they
    # receive no weight in the dense softmax.
    a[~mask] = float("-inf")

    a_sparse.requires_grad_(True)
    a.requires_grad_(True)

    res_gt = torch.softmax(a, dim=-1)
    res_sparse = torch.softmax(a_sparse, dim=-1)

    res = res_sparse.to_dense()

    assert res.dtype == res_gt.dtype
    assert torch.allclose(
        res, res_gt, atol=atol
    ), f"{torch.max(torch.abs(res- res_gt))}"

    # WARNING: gradients are modified in-place!
    res_sparse.values().backward(torch.ones_like(res_sparse.values()))
    res_gt.backward(torch.ones_like(res_gt))

    # Zero the dense gradient outside the mask before comparing.
    a_grad = a.grad.clone()
    a_grad[~mask] = 0

    assert torch.allclose(
        a_grad, a_sparse.grad.to_dense(), atol=atol
    ), f"{torch.max(torch.abs(a_grad- a_sparse.grad.to_dense()))}"


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_deepcopy(tensor_type, device):
    """Check that ``copy.deepcopy`` of a sparse tensor compares equal to it."""
    N, C, H, W = 8, 2, 64, 64
    dtype = torch.float32
    sparsity = 0.8

    shape0 = (N, C, H, W)
    if tensor_type is not BlockSparseTensor:
        # The CSR helper asserts a 3-dimensional input: drop the leading dim.
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )

    b_sparse = copy.deepcopy(a_sparse)
    assert torch.equal(a_sparse, b_sparse)


@pytest.mark.parametrize("tensor_type", _tensor_types)
@pytest.mark.parametrize("device", _devices)
def test_module_buffer(tensor_type, device):
    """Check that sparse tensors work as ``torch.nn.Module`` buffers.

    Covers ``register_buffer``, ``Module.to``, ``state_dict`` and
    ``load_state_dict``.
    """
    N, C, H, W = 8, 2, 64, 64
    dtype = torch.float32
    sparsity = 0.8

    shape0 = (N, C, H, W)
    if tensor_type is not BlockSparseTensor:
        # The CSR helper asserts a 3-dimensional input: drop the leading dim.
        shape0 = shape0[1:]

    a_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )
    b_sparse = _create_tensor(
        tensor_type, device, dtype=dtype, shape=shape0, sparsity=sparsity
    )

    module = torch.nn.Module()
    # test that register_buffer works
    module.register_buffer("a_sparse", a_sparse)

    assert module.a_sparse is a_sparse

    module.to(device)
    assert module.a_sparse.device == torch.device(device)

    state_dict = module.state_dict()
    assert "a_sparse" in state_dict
    assert torch.equal(a_sparse.to(device), state_dict["a_sparse"])

    # Round-trip the module's own state dict, then load a different tensor.
    module.load_state_dict(state_dict)

    module.load_state_dict({"a_sparse": b_sparse})
    assert torch.equal(module.a_sparse, b_sparse.to(device))