# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from xformers.components.attention import maybe_sparsify
from xformers.components.attention._sputnik_sparse import _dense_to_sparse
from xformers.components.attention.core import SparseCS, _create_random_sparsity

B = 2
M = 16  # not a nice round number, on purpose

_devices_list = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]
_devices = [torch.device(d) for d in _devices_list]
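# A note on the layout exercised below (an inference from the fields used in
# _baseline_dense_to_sparse, not a spec): SparseCS stores a boolean mask in a
# compressed-sparse-row (CSR) style layout, as the quadruple (values,
# row_indices, row_offsets, column_indices) consumed by the Sputnik kernels.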
@pytest.mark.parametrize("device", _devices)
def test_logical_and(device):
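    """Boolean & on SparseCS masks: &-ing two SparseCS objects should raise
    (unsupported for now), &-ing with a dense all-ones mask is a no-op, and
    &-ing with the inverse mask leaves no stored values."""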
    mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), 0.1)
    mask_cs = SparseCS(mask, device)

    # Check that we cannot & two sparse matrices (for now)
    with pytest.raises(Exception):
        _ = mask_cs & mask_cs

    # Check that & ones returns the same values
    mask_ones = mask_cs & torch.ones_like(mask, dtype=torch.bool, device=device)
    assert torch.allclose(mask_cs.to_dense().long(), mask_ones.to_dense().long())

    # Check that & the inverse returns 0 all around
    mask_not = ~mask.to(device)
    assert (mask_cs & mask_not).values.numel() == 0


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("seq", [12, 32, 128])
def test_dense_sparse(seq, device):
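    """Round-trip a random mask through SparseCS -> to_dense() -> SparseCS and
    check that no entries are gained or lost."""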
    # Check that we can .to_dense() without crashing
    mask = torch.rand(seq, seq, device=device) > 0.1
    mask_cs = SparseCS(mask, device)
    mask_back_forth = SparseCS(mask_cs.to_dense(), device)

    assert torch.allclose(mask_cs.to_dense().long(), mask_back_forth.to_dense().long())


@pytest.mark.parametrize("device", _devices)
def test_device(device):
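    """Check that maybe_sparsify keeps the mask on the device it was created on."""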
    mask = _create_random_sparsity(
        torch.ones(B, M, M, dtype=torch.bool, device=device), 0.1
    )
    assert mask.device.type == device.type

    sparse_mask = maybe_sparsify(mask)
    assert sparse_mask.device.type == device.type


def _baseline_dense_to_sparse(matrix):
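    """Numpy reference implementation of the dense -> sparse conversion.

    Returns the same (values, row_indices, row_offsets, column_indices)
    quadruple as _dense_to_sparse. As a worked example, the matrix
    [[0., 2.], [3., 0.]] yields values=[2., 3.], row_offsets=[0, 1, 2]
    and column_indices=[1, 0].
    """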
    import numpy as np

    # Extract the nonzero values.
    values = matrix.compress((matrix != 0).flatten())

    # Calculate the offset of each row.
    mask = (matrix != 0).astype(np.int32)
    row_offsets = np.concatenate(([0], np.cumsum(np.add.reduce(mask, axis=1))), axis=0)

    # Create the row indices, sorted by decreasing row length.
    # Note: use torch.argsort rather than np.argsort so that ties are broken the
    # same way as in _dense_to_sparse, since sorting is not stable in PyTorch.
    row_indices = torch.argsort(-1 * torch.as_tensor(np.diff(row_offsets))).numpy()

    # Extract the column indices for the nonzero values.
    x = mask * (np.arange(matrix.shape[1]) + 1)
    column_indices = x.compress((x != 0).flatten())
    column_indices = column_indices - 1

    # Cast to the desired precisions.
    values = torch.as_tensor(values.astype(np.float32))
    row_indices, row_offsets, column_indices = [
        torch.as_tensor(x.astype(np.int32))
        for x in [row_indices, row_offsets, column_indices]
    ]
    return values, row_indices, row_offsets, column_indices


@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("seq", [12, 32, 128])
def test_dense_to_sparse(seq, device):
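    """Compare _dense_to_sparse against the numpy reference on a random matrix
    with roughly 10% of its entries zeroed, checking values, offsets, indices,
    dtypes and devices."""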
    matrix = torch.rand(seq, seq, device=device)
    matrix[matrix > 0.9] = 0

    baseline_res = _baseline_dense_to_sparse(matrix.cpu().numpy())
    res = _dense_to_sparse(matrix, device=device)

    _idx_to_name = ["values", "row_indices", "row_offsets", "column_indices"]
    for idx, (bi, i) in enumerate(zip(baseline_res, res)):
        # Skip the value comparison for row_indices (idx == 1): it is the result
        # of an argsort, which is not stable for rows with the same number of
        # nonzeros.
        if idx != 1:
            assert torch.allclose(bi.to(device), i), f"error in {_idx_to_name[idx]}"
        assert bi.dtype == i.dtype
        assert i.device == device