# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components.attention import maybe_sparsify
from xformers.components.attention._sputnik_sparse import _dense_to_sparse
from xformers.components.attention.core import SparseCS, _create_random_sparsity

B = 2
M = 16  # not a nice round number, on purpose

_devices_list = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]
_devices = [torch.device(d) for d in _devices_list]


@pytest.mark.parametrize("device", _devices)
def test_logical_and(device):
    mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), 0.1)
    mask_cs = SparseCS(mask, device)

    # Check that we cannot & two sparse matrices (for now)
    with pytest.raises(Exception):
        _ = mask_cs & mask_cs

    # Check that & ones returns the same values
    mask_ones = mask_cs & torch.ones_like(mask, dtype=torch.bool, device=device)
    assert torch.allclose(mask_cs.to_dense().long(), mask_ones.to_dense().long())

    # Check that & the inverse returns 0 all around
    mask_not = ~mask.to(device)
    assert (mask_cs & mask_not).values.numel() == 0


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("seq", [12, 32, 128])
def test_dense_sparse(seq, device):
    # Check that we can round-trip dense -> sparse -> dense without crashing
    # and without losing any values
    mask = torch.rand(seq, seq, device=device) > 0.1
    mask_cs = SparseCS(mask, device)
    mask_back_forth = SparseCS(mask_cs.to_dense(), device)

    assert torch.allclose(mask_cs.to_dense().long(), mask_back_forth.to_dense().long())


@pytest.mark.parametrize("device", _devices)
def test_device(device):
    # Check that sparsifying a mask preserves its device
    mask = _create_random_sparsity(
        torch.ones(B, M, M, dtype=torch.bool, device=device), 0.1
    )
    assert mask.device.type == device.type

    sparse_mask = maybe_sparsify(mask)
    assert sparse_mask.device.type == device.type


def _baseline_dense_to_sparse(matrix):
    # Reference numpy implementation of the dense -> CSR conversion,
    # used to validate _dense_to_sparse below
    import numpy as np

    # Extract the nonzero values.
    values = matrix.compress((matrix != 0).flatten())

    # Calculate the offset of each row.
    mask = (matrix != 0).astype(np.int32)
    row_offsets = np.concatenate(([0], np.cumsum(np.add.reduce(mask, axis=1))), axis=0)

    # Create the row indices and sort them.
    # NOTE: use torch.argsort so the results stay comparable with
    # _dense_to_sparse, since sorting is not stable in PyTorch
    row_indices = torch.argsort(-1 * torch.as_tensor(np.diff(row_offsets))).numpy()

    # Extract the column indices for the nonzero values.
    x = mask * (np.arange(matrix.shape[1]) + 1)
    column_indices = x.compress((x != 0).flatten())
    column_indices = column_indices - 1

    # Cast to the desired precision.
    values = torch.as_tensor(values.astype(np.float32))
    row_indices, row_offsets, column_indices = [
        torch.as_tensor(x.astype(np.int32))
        for x in [row_indices, row_offsets, column_indices]
    ]
    return values, row_indices, row_offsets, column_indices


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("seq", [12, 32, 128])
def test_dense_to_sparse(seq, device):
    # Check the CSR conversion against the numpy baseline above
    matrix = torch.rand(seq, seq, device=device)
    matrix[matrix > 0.9] = 0
    baseline_res = _baseline_dense_to_sparse(matrix.cpu().numpy())
    res = _dense_to_sparse(matrix, device=device)

    _idx_to_name = ["values", "row_indices", "row_offsets", "column_indices"]

    for idx, (bi, i) in enumerate(zip(baseline_res, res)):
        if idx != 1:
            # row_indices is the result of an argsort, which is not stable
            # for the same number of elements, so skip the value comparison
            assert torch.allclose(bi.to(device), i), f"error in {_idx_to_name[idx]}"
        assert bi.dtype == i.dtype
        assert i.device == device
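

# Minimal round-trip sketch: rebuild a dense matrix from the CSR components
# produced by _baseline_dense_to_sparse, to make the (values, row_offsets,
# column_indices) layout concrete. The _csr_to_dense helper and this test are
# illustrative assumptions, not xformers API.
def _csr_to_dense(values, row_offsets, column_indices, shape):
    # For each row i, the nonzeros live in the span
    # [row_offsets[i], row_offsets[i + 1]); scatter them back to their columns
    out = torch.zeros(shape, dtype=values.dtype)
    for row in range(shape[0]):
        start, end = int(row_offsets[row]), int(row_offsets[row + 1])
        out[row, column_indices[start:end].long()] = values[start:end]
    return out


@pytest.mark.parametrize("seq", [12, 32])
def test_baseline_round_trip(seq):
    # Sanity check on CPU: dense -> CSR -> dense should be lossless
    matrix = torch.rand(seq, seq)
    matrix[matrix > 0.9] = 0
    values, _, row_offsets, column_indices = _baseline_dense_to_sparse(matrix.numpy())
    rebuilt = _csr_to_dense(values, row_offsets, column_indices, (seq, seq))
    assert torch.allclose(rebuilt, matrix)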