# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch

from xformers.components.attention import maybe_sparsify
from xformers.components.attention._sputnik_sparse import _dense_to_sparse
from xformers.components.attention.core import SparseCS, _create_random_sparsity

B = 2
M = 16  # not a nice round number, on purpose

_devices_list = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]
_devices = [torch.device(d) for d in _devices_list]
@pytest.mark.parametrize("device", _devices)
def test_logical_and(device):
    """Check the `&` operator of a SparseCS mask against dense boolean masks.

    Without the parametrize decorator pytest reports
    "fixture 'device' not found" — the `_devices` list at module level is
    clearly meant to feed this parameter.
    """
    mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), 0.1)
    mask_cs = SparseCS(mask, device)

    # Check that we cannot & two sparse matrices (for now)
    with pytest.raises(Exception):
        _ = mask_cs & mask_cs

    # Check that & ones returns the same values
    mask_ones = mask_cs & torch.ones_like(mask, dtype=torch.bool, device=device)
    assert torch.allclose(mask_cs.to_dense().long(), mask_ones.to_dense().long())

    # Check that & the inverse returns 0 all around
    mask_not = ~mask.to(device)
    assert (mask_cs & mask_not).values.numel() == 0
@pytest.mark.parametrize("seq", [12, 32, 128])
@pytest.mark.parametrize("device", _devices)
def test_dense_sparse(seq, device):
    """Round-trip a random boolean mask through SparseCS and back to dense.

    The `seq`/`device` arguments need parametrize decorators, otherwise
    pytest fails to collect the test ("fixture 'seq' not found").
    """
    # Check that we can .to_dense() without crashing
    mask = torch.rand(seq, seq, device=device) > 0.1
    mask_cs = SparseCS(mask, device)
    mask_back_forth = SparseCS(mask_cs.to_dense(), device)
    assert torch.allclose(mask_cs.to_dense().long(), mask_back_forth.to_dense().long())
@pytest.mark.parametrize("device", _devices)
def test_device(device):
    """Sparsifying a mask must keep it on the device it was created on.

    The `device` argument needs the parametrize decorator, otherwise
    pytest fails to collect the test ("fixture 'device' not found").
    """
    mask = _create_random_sparsity(
        torch.ones(B, M, M, dtype=torch.bool, device=device), 0.1
    )
    assert mask.device.type == device.type

    sparse_mask = maybe_sparsify(mask)
    assert sparse_mask.device.type == device.type
def _baseline_dense_to_sparse(matrix): | |
import numpy as np | |
# Extract the nonzero values. | |
values = matrix.compress((matrix != 0).flatten()) | |
# Calculate the offset of each row. | |
mask = (matrix != 0).astype(np.int32) | |
row_offsets = np.concatenate(([0], np.cumsum(np.add.reduce(mask, axis=1))), axis=0) | |
# Create the row indices and sort them. | |
# note: use torch.argsort to make it compatible as sorting is not stable in PyTorch | |
row_indices = torch.argsort(-1 * torch.as_tensor(np.diff(row_offsets))).numpy() | |
# Extract the column indices for the nonzero values. | |
x = mask * (np.arange(matrix.shape[1]) + 1) | |
column_indices = x.compress((x != 0).flatten()) | |
column_indices = column_indices - 1 | |
# Cast the desired precision. | |
values = torch.as_tensor(values.astype(np.float32)) | |
row_indices, row_offsets, column_indices = [ | |
torch.as_tensor(x.astype(np.int32)) | |
for x in [row_indices, row_offsets, column_indices] | |
] | |
return values, row_indices, row_offsets, column_indices | |
@pytest.mark.parametrize("seq", [12, 32, 128])
@pytest.mark.parametrize("device", _devices)
def test_dense_to_sparse(seq, device):
    """Compare `_dense_to_sparse` against the numpy reference implementation.

    The `seq`/`device` arguments need parametrize decorators, otherwise
    pytest fails to collect the test ("fixture 'seq' not found").
    """
    matrix = torch.rand(seq, seq, device=device)
    matrix[matrix > 0.9] = 0
    baseline_res = _baseline_dense_to_sparse(matrix.cpu().numpy())
    res = _dense_to_sparse(matrix, device=device)

    _idx_to_name = ["values", "row_indices", "row_offsets", "column_indices"]
    for idx, (bi, i) in enumerate(zip(baseline_res, res)):
        if idx != 1:
            # row_indices is the result of an argsort, which is not stable
            # for same number of elements
            assert torch.allclose(bi.to(device), i), f"error in {_idx_to_name[idx]}"
        assert bi.dtype == i.dtype
        assert i.device == device