Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
import pytest | |
import torch | |
# needed to register custom ops | |
import xformers # noqa: F401 | |
import xformers.components.attention.core | |
from xformers.components.attention._sputnik_sparse import _csr_to_coo | |
from xformers.components.attention.core import ( | |
_broadcast_batch, | |
_create_random_sparsity, | |
_sparse_bmm, | |
) | |
cuda_only = pytest.mark.skipif( | |
not torch.cuda.is_available() or not torch.version.cuda, reason="requires CUDA" | |
) | |
_devices = ( | |
["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"] | |
) | |
def _baseline_matmul_with_sparse_mask( | |
a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor | |
) -> torch.Tensor: | |
assert a.ndim == b.ndim | |
assert mask.ndim == a.ndim | |
assert a.shape[-1] == b.shape[-2] | |
assert a.shape[-2] == mask.shape[-2], f"{a.shape}, {mask.shape}" | |
assert b.shape[-1] == mask.shape[-1], f"{b.shape}, {mask.shape}" | |
assert a.shape[:-2] == b.shape[:-2], f"{a.shape}, {b.shape}" | |
assert a.shape[:-2] == mask.shape[:-2], f"{a.shape}, {mask.shape}" | |
idxs = mask.indices().unbind() | |
b = b.transpose(-2, -1) | |
# compute matmul for elements within the mask | |
val = (a[idxs[:-2] + (idxs[-2], slice(None))] * b[idxs[:-2] + (idxs[-1], slice(None))]).sum(-1) # type: ignore | |
out_shape = a.shape[:-1] + (b.shape[-2],) | |
res = torch.sparse_coo_tensor(torch.stack(idxs), val, out_shape) | |
return res | |
def _baseline_matmul_with_dense_mask( | |
a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor | |
) -> torch.Tensor: | |
res = a @ b | |
res[~mask] = float("-inf") | |
return res | |
def _baseline_sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: | |
# need to use torch.sparse.mm to get gradients wrt sparse matrix a | |
# TODO implement this in C++ / CUDA as this is slow! | |
out = [] | |
for ai, bi in zip(a, b): | |
out.append(torch.sparse.mm(ai, bi)) | |
return torch.stack(out, dim=0) | |
def test_matmul_with_mask(device, contiguous, is_sparse): | |
B, L, K = 8, 30, 32 | |
prob = 0.5 | |
a = torch.rand(B, L, K, device=device) | |
b = torch.rand(B, K, L, device=device) | |
if not contiguous: | |
a = a.transpose(-2, -1).contiguous().transpose(-2, -1) | |
b = b.transpose(-2, -1).contiguous().transpose(-2, -1) | |
mask = torch.rand(B, L, L, device=device) > prob | |
fn = torch.ops.xformers.matmul_with_mask | |
fn_gt = _baseline_matmul_with_dense_mask | |
if is_sparse: | |
mask = mask.to_sparse() | |
fn_gt = _baseline_matmul_with_sparse_mask | |
res = fn(a, b, mask) | |
res_gt = fn_gt(a, b, mask) | |
if is_sparse: | |
res = res.to_dense() | |
res_gt = res_gt.to_dense() | |
assert res.dtype == res_gt.dtype | |
assert torch.allclose(res, res_gt) | |
def test_matmul_with_mask_backward(device, contiguous, is_sparse): | |
if device == "cuda" and is_sparse is False: | |
# Skip test for now due to bug in torch 1.8 | |
# See https://github.com/pytorch/pytorch/issues/54975 | |
# Broken CUDA / torch 1.8 combination, awaiting an update | |
return | |
B, L, K = 8, 10, 16 | |
prob = 0.5 | |
a = torch.rand(B, L, K, device=device, requires_grad=True) | |
b = torch.rand(B, K, L, device=device, requires_grad=True) | |
if not contiguous: | |
a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_() | |
b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_() | |
mask = torch.rand(B, L, L, device=device) > prob | |
fn = torch.ops.xformers.matmul_with_mask | |
fn_gt = _baseline_matmul_with_dense_mask | |
if is_sparse: | |
mask = mask.to_sparse() | |
fn_gt = _baseline_matmul_with_sparse_mask | |
def compute_grads(f): | |
out = f(a, b, mask) | |
if is_sparse: | |
out = out.to_dense() | |
out.sum().backward() | |
compute_grads(fn) | |
grad_a = a.grad.clone() | |
grad_b = b.grad.clone() | |
a.grad = None | |
b.grad = None | |
compute_grads(fn_gt) | |
assert torch.allclose(grad_a, a.grad) | |
assert torch.allclose(grad_b, b.grad) | |
def test_sddmm_sputnik(device): | |
B, L, M, K = 8, 30, 16, 32 | |
prob = 0.5 | |
a = torch.rand(B, L, K, device=device) | |
b = torch.rand(B, M, K, device=device).transpose(-2, -1) | |
mask = _create_random_sparsity( | |
torch.ones(B, L, M, dtype=torch.bool, device=device), prob | |
) | |
mask_csr = xformers.components.attention.core.SparseCS(mask, device) | |
fn = xformers.components.attention.core._matmul_with_mask | |
mask = mask.to_sparse() | |
res = fn(a, b, mask_csr) | |
res_gt = fn(a, b, mask) | |
res = res.to_dense() | |
res_gt = res_gt.to_dense() | |
assert res.dtype == res_gt.dtype | |
assert torch.allclose(res, res_gt) | |
def test_sddmm_csr(L, M, K, prob): | |
device = torch.device("cuda") | |
# TODO add more checks for different nnz | |
B = 8 | |
a = torch.rand(B, L, K, device=device) | |
b = torch.rand(B, M, K, device=device) | |
mask = _create_random_sparsity( | |
torch.ones(B, L, M, dtype=torch.bool, device=device), prob | |
) | |
mask_csr = xformers.components.attention.core.SparseCS(mask, device) | |
row_indices = mask_csr.row_indices | |
row_offsets = mask_csr.row_offsets | |
column_indices = mask_csr.column_indices | |
fn = torch.ops.xformers.csr_sddmm | |
fn_gt = torch.ops.xformers.sddmm_sputnik | |
res = fn(a, b, row_indices, row_offsets, column_indices) | |
res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices) | |
assert res.dtype == res_gt.dtype | |
assert torch.allclose(res, res_gt, atol=1e-6) | |
def test_sddmm_csr_per_nnz(nnz): | |
device = torch.device("cuda") | |
B = 8 | |
L, M, K = 1024, 1024, 32 | |
a = torch.rand(B, L, K, device=device) | |
b = torch.rand(B, M, K, device=device) | |
mask = torch.zeros(L, M, dtype=torch.bool, device=device) | |
mask.view(-1)[: nnz - 1] = True | |
mask[-1, -1] = True | |
mask_csr = xformers.components.attention.core.SparseCS(mask, device) | |
row_indices = mask_csr.row_indices | |
row_offsets = mask_csr.row_offsets | |
column_indices = mask_csr.column_indices | |
fn = torch.ops.xformers.csr_sddmm | |
fn_gt = torch.ops.xformers.sddmm_sputnik | |
res = fn(a, b, row_indices, row_offsets, column_indices) | |
res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices) | |
assert res.dtype == res_gt.dtype | |
assert torch.allclose(res, res_gt, atol=1e-6) | |
def test_sddmm_coo(L, M, K, prob): | |
device = torch.device("cuda") | |
# TODO add more checks for different nnz | |
B = 8 | |
a = torch.rand(B, L, K, device=device) | |
b = torch.rand(B, M, K, device=device) | |
mask = _create_random_sparsity( | |
torch.ones(B, L, M, dtype=torch.bool, device=device), prob | |
) | |
mask_csr = xformers.components.attention.core.SparseCS(mask, device) | |
row_indices = mask_csr.row_indices | |
row_offsets = mask_csr.row_offsets | |
column_indices = mask_csr.column_indices | |
fn = torch.ops.xformers.coo_sddmm | |
fn_gt = torch.ops.xformers.sddmm_sputnik | |
# convert from csr to coo | |
row_coo, _ = _csr_to_coo(L, M, row_offsets, column_indices) | |
res = fn(a, b, row_indices, row_coo, column_indices) | |
res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices) | |
assert res.dtype == res_gt.dtype | |
assert torch.allclose(res, res_gt, atol=1e-6) | |
def test_sddmm_sputnik_backward(device): | |
contiguous = True | |
B, L, M, K = 8, 10, 16, 32 | |
prob = 0.5 | |
a = torch.rand(B, L, K, device=device, requires_grad=True) | |
b = torch.rand(B, M, K, device=device).transpose(-2, -1).requires_grad_(True) | |
if not contiguous: | |
a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_() | |
b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_() | |
mask = _create_random_sparsity( | |
torch.ones(B, L, M, dtype=torch.bool, device=device), prob | |
) | |
mask_csr = xformers.components.attention.core.SparseCS(mask, device) | |
fn = xformers.components.attention.core._matmul_with_mask | |
mask = mask.to_sparse() | |
out_csr = fn(a, b, mask_csr) | |
out_csr.values.sum().backward() | |
grad_a = a.grad.clone() | |
grad_b = b.grad.clone() | |
a.grad = None | |
b.grad = None | |
# fn(a[None], b[None], mask).coalesce().values().sum().backward() # TODO check why this fails | |
fn(a, b, mask).to_dense().sum().backward() | |
assert torch.allclose(grad_a, a.grad, atol=1e-7) | |
assert torch.allclose(grad_b, b.grad, atol=1e-7) | |
def test_sparse_softmax_sputnik(device): | |
B, L = 8, 30 | |
prob = 0.5 | |
a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob) | |
a_csr = xformers.components.attention.core.SparseCS(a, device) | |
fn = xformers.components.attention.core._softmax | |
a = a.to_sparse() | |
res = fn(a_csr) | |
res_gt = fn(a) | |
res = res.to_dense() | |
res_gt = res_gt.to_dense() | |
assert res.dtype == res_gt.dtype | |
assert torch.allclose(res, res_gt) | |
def test_sparse_softmax_sputnik_backward(device): | |
B, L = 8, 30 | |
prob = 0.5 | |
a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob) | |
a_csr = xformers.components.attention.core.SparseCS(a, device) | |
fn = xformers.components.attention.core._softmax | |
a = a.to_sparse() | |
a_csr.values.requires_grad_(True) | |
fn(a_csr).values.sum().backward() | |
grad_a = a_csr.values.grad.clone() | |
a.requires_grad_(True) | |
fn(a).coalesce().values().sum().backward() | |
assert torch.allclose( | |
grad_a, a.grad.coalesce().values().reshape_as(grad_a), atol=1e-7 | |
) | |
def test_spmm_sputnik(device): | |
B, L, K = 8, 30, 32 | |
prob = 0.5 | |
a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob) | |
b = torch.rand(B, L, K, device=device) | |
a_csr = xformers.components.attention.core.SparseCS(a, device) | |
fn = xformers.components.attention.core.bmm | |
a = a.to_sparse() | |
res = fn(a_csr, b) | |
res_gt = fn(a, b) | |
res = res | |
res_gt = res_gt | |
assert res.dtype == res_gt.dtype | |
assert torch.allclose(res, res_gt) | |
def test_spmm_sputnik_backward(device): | |
B, M, L, K = 8, 16, 30, 32 | |
prob = 0.5 | |
a = _create_random_sparsity(torch.rand(B, M, L, device=device), prob) | |
b = torch.rand(B, L, K, device=device) | |
b.requires_grad_(True) | |
a_csr = xformers.components.attention.core.SparseCS(a, device) | |
fn = xformers.components.attention.core.bmm | |
a = a.to_sparse() | |
a.requires_grad_(True) | |
a_csr.values.requires_grad_(True) | |
fn(a_csr, b).sum().backward() | |
grad_a = a_csr.values.grad.clone() | |
grad_b = b.grad.clone() | |
b.grad = None | |
fn(a, b).sum().backward() | |
assert torch.allclose( | |
grad_a, a.grad.coalesce().values().reshape_as(grad_a), atol=1e-7 | |
) | |
assert torch.allclose(grad_b, b.grad, atol=1e-7) | |
def test_csr_transpose(): | |
B, L, K = 8, 30, 40 | |
prob = 0.5 | |
device = torch.device("cuda") | |
a = _create_random_sparsity(torch.rand(B, L, K, device=device), prob) | |
a_csr = xformers.components.attention.core.SparseCS(a, device) | |
res = a_csr.transpose() | |
res2 = res.transpose() | |
assert torch.allclose(res.to_dense(), a.transpose(-2, -1)) | |
assert torch.allclose(res2.to_dense(), a) | |
# cover > 0.995 | |
# cover > 64 | |
def test_sparse_bmm(device, contiguous, prob, N): | |
B, M = 8, 64 | |
a = torch.rand(B, M, N, device=device) | |
a[a < prob] = 0 | |
a = a.to_sparse() | |
b = torch.rand(B, N, M, device=device) | |
if not contiguous: | |
a = a + a | |
b = b.transpose(-2, -1).contiguous().transpose(-2, -1) | |
res = _sparse_bmm(a, b) | |
res_gt = _baseline_sparse_bmm(a, b) | |
assert torch.allclose(res, res_gt) | |
def test_sparse_bmm_backward(device, contiguous): | |
if device == "cuda": | |
# Skip test for now due to bug in torch 1.8 | |
# See https://github.com/pytorch/pytorch/issues/54975 | |
# Broken CUDA / torch 1.8 combination, awaiting an update | |
return | |
B, L, K = 8, 10, 16 | |
prob = 0.5 | |
a = torch.rand(B, L, K, device=device) | |
a[a < prob] = 0 | |
a = a.to_sparse() | |
b = torch.rand(B, K, L, device=device, requires_grad=True) | |
if not contiguous: | |
a = a + a | |
b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_() | |
a.requires_grad_(True) | |
def compute_grads(f): | |
out = f(a, b) | |
out.sum().backward() | |
compute_grads(_sparse_bmm) | |
grad_a = a.grad.clone().coalesce() | |
grad_b = b.grad.clone() | |
a.grad = None | |
b.grad = None | |
compute_grads(_baseline_sparse_bmm) | |
new_grad_a = a.grad.coalesce() | |
assert torch.allclose(grad_a.indices(), new_grad_a.indices()) | |
assert torch.allclose(grad_a.values(), new_grad_a.values()) | |
assert torch.allclose(grad_b, b.grad) | |
def test_sparse_coo_broadcast(device): | |
B, L, K = 8, 10, 16 | |
prob = 0.5 | |
a = torch.rand(L, K, device=device) | |
a[a < prob] = 0 | |
a_sparse = a.to_sparse() | |
res = _broadcast_batch(a_sparse, B) | |
res_gt = a[None, :, :].expand(B, L, K) | |
assert torch.allclose(res.to_dense(), res_gt) | |