# xformers/tests/test_custom_ops.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
# needed to register custom ops
import xformers # noqa: F401
import xformers.components.attention.core
from xformers.components.attention._sputnik_sparse import _csr_to_coo
from xformers.components.attention.core import (
_broadcast_batch,
_create_random_sparsity,
_sparse_bmm,
)
cuda_only = pytest.mark.skipif(
not torch.cuda.is_available() or not torch.version.cuda, reason="requires CUDA"
)
_devices = (
["cpu", "cuda"] if torch.cuda.is_available() and torch.version.cuda else ["cpu"]
)
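# torch.version.cuda is None on CPU-only builds (and on ROCm builds, where
# torch.cuda.is_available() can still return True), so both checks are used
# to restrict the CUDA-only tests to an actual CUDA runtime.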
def _baseline_matmul_with_sparse_mask(
a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
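    """Reference masked matmul (SDDMM) used as ground truth in the tests below.

    ``mask`` is a sparse COO boolean tensor. For every non-zero position
    (i, j) the dot product between row i of ``a`` and column j of ``b`` is
    computed explicitly, and the result is returned as a sparse COO tensor
    with the same sparsity pattern as ``mask``.
    """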
assert a.ndim == b.ndim
assert mask.ndim == a.ndim
assert a.shape[-1] == b.shape[-2]
assert a.shape[-2] == mask.shape[-2], f"{a.shape}, {mask.shape}"
assert b.shape[-1] == mask.shape[-1], f"{b.shape}, {mask.shape}"
assert a.shape[:-2] == b.shape[:-2], f"{a.shape}, {b.shape}"
assert a.shape[:-2] == mask.shape[:-2], f"{a.shape}, {mask.shape}"
idxs = mask.indices().unbind()
b = b.transpose(-2, -1)
# compute matmul for elements within the mask
val = (a[idxs[:-2] + (idxs[-2], slice(None))] * b[idxs[:-2] + (idxs[-1], slice(None))]).sum(-1) # type: ignore
out_shape = a.shape[:-1] + (b.shape[-2],)
res = torch.sparse_coo_tensor(torch.stack(idxs), val, out_shape)
return res
def _baseline_matmul_with_dense_mask(
a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
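    """Reference masked matmul with a dense boolean mask.

    Computes the full ``a @ b`` and fills positions where ``mask`` is False
    with ``-inf``, which is the convention the dense code path is compared
    against.
    """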
res = a @ b
res[~mask] = float("-inf")
return res
def _baseline_sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
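    """Reference batched (sparse @ dense) matmul, one torch.sparse.mm per batch.

    Unlike ``torch.bmm``, ``torch.sparse.mm`` provides gradients with respect
    to the sparse operand ``a``, which the backward tests rely on.
    """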
# need to use torch.sparse.mm to get gradients wrt sparse matrix a
# TODO implement this in C++ / CUDA as this is slow!
out = []
for ai, bi in zip(a, b):
out.append(torch.sparse.mm(ai, bi))
return torch.stack(out, dim=0)
@pytest.mark.parametrize("is_sparse", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_matmul_with_mask(device, contiguous, is_sparse):
B, L, K = 8, 30, 32
prob = 0.5
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, K, L, device=device)
if not contiguous:
a = a.transpose(-2, -1).contiguous().transpose(-2, -1)
b = b.transpose(-2, -1).contiguous().transpose(-2, -1)
mask = torch.rand(B, L, L, device=device) > prob
fn = torch.ops.xformers.matmul_with_mask
fn_gt = _baseline_matmul_with_dense_mask
if is_sparse:
mask = mask.to_sparse()
fn_gt = _baseline_matmul_with_sparse_mask
res = fn(a, b, mask)
res_gt = fn_gt(a, b, mask)
if is_sparse:
res = res.to_dense()
res_gt = res_gt.to_dense()
assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("is_sparse", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_matmul_with_mask_backward(device, contiguous, is_sparse):
if device == "cuda" and is_sparse is False:
# Skip test for now due to bug in torch 1.8
# See https://github.com/pytorch/pytorch/issues/54975
# Broken CUDA / torch 1.8 combination, awaiting an update
return
B, L, K = 8, 10, 16
prob = 0.5
a = torch.rand(B, L, K, device=device, requires_grad=True)
b = torch.rand(B, K, L, device=device, requires_grad=True)
if not contiguous:
a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
mask = torch.rand(B, L, L, device=device) > prob
fn = torch.ops.xformers.matmul_with_mask
fn_gt = _baseline_matmul_with_dense_mask
if is_sparse:
mask = mask.to_sparse()
fn_gt = _baseline_matmul_with_sparse_mask
def compute_grads(f):
out = f(a, b, mask)
if is_sparse:
out = out.to_dense()
out.sum().backward()
compute_grads(fn)
grad_a = a.grad.clone()
grad_b = b.grad.clone()
a.grad = None
b.grad = None
compute_grads(fn_gt)
assert torch.allclose(grad_a, a.grad)
assert torch.allclose(grad_b, b.grad)
@pytest.mark.parametrize("device", _devices)
def test_sddmm_sputnik(device):
B, L, M, K = 8, 30, 16, 32
prob = 0.5
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, M, K, device=device).transpose(-2, -1)
mask = _create_random_sparsity(
torch.ones(B, L, M, dtype=torch.bool, device=device), prob
)
mask_csr = xformers.components.attention.core.SparseCS(mask, device)
fn = xformers.components.attention.core._matmul_with_mask
mask = mask.to_sparse()
res = fn(a, b, mask_csr)
res_gt = fn(a, b, mask)
res = res.to_dense()
res_gt = res_gt.to_dense()
assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt)
@cuda_only
@pytest.mark.parametrize("prob", [0.5, 1])
@pytest.mark.parametrize("K", [32, 17])
@pytest.mark.parametrize("M", [30, 17])
@pytest.mark.parametrize("L", [30, 17])
def test_sddmm_csr(L, M, K, prob):
device = torch.device("cuda")
# TODO add more checks for different nnz
B = 8
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, M, K, device=device)
mask = _create_random_sparsity(
torch.ones(B, L, M, dtype=torch.bool, device=device), prob
)
mask_csr = xformers.components.attention.core.SparseCS(mask, device)
row_indices = mask_csr.row_indices
row_offsets = mask_csr.row_offsets
column_indices = mask_csr.column_indices
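    # Illustrative sketch of the CSR layout consumed by the sputnik kernels,
    # assuming a single 2x3 mask [[1, 0, 1], [0, 1, 0]]:
    #   row_offsets    ~ [0, 2, 3]  # prefix sum of non-zeros per row
    #   column_indices ~ [0, 2, 1]  # column of each non-zero, row by row
    #   row_indices    ~ [0, 1]     # one entry per row (ordering is an
    #                               # implementation detail of SparseCS)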
fn = torch.ops.xformers.csr_sddmm
fn_gt = torch.ops.xformers.sddmm_sputnik
res = fn(a, b, row_indices, row_offsets, column_indices)
res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices)
assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt, atol=1e-6)
@cuda_only
@pytest.mark.parametrize("nnz", [0, 4, 16, 20, 36])
def test_sddmm_csr_per_nnz(nnz):
device = torch.device("cuda")
B = 8
L, M, K = 1024, 1024, 32
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, M, K, device=device)
mask = torch.zeros(L, M, dtype=torch.bool, device=device)
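    # place non-zeros at the start of the flattened mask plus one at the
    # bottom-right corner; for nnz >= 1 this yields exactly nnz entries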
mask.view(-1)[: nnz - 1] = True
mask[-1, -1] = True
mask_csr = xformers.components.attention.core.SparseCS(mask, device)
row_indices = mask_csr.row_indices
row_offsets = mask_csr.row_offsets
column_indices = mask_csr.column_indices
fn = torch.ops.xformers.csr_sddmm
fn_gt = torch.ops.xformers.sddmm_sputnik
res = fn(a, b, row_indices, row_offsets, column_indices)
res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices)
assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt, atol=1e-6)
@cuda_only
@pytest.mark.parametrize("prob", [0.5, 1])
@pytest.mark.parametrize("K", [32, 17])
@pytest.mark.parametrize("M", [30, 17])
@pytest.mark.parametrize("L", [30, 17])
def test_sddmm_coo(L, M, K, prob):
device = torch.device("cuda")
# TODO add more checks for different nnz
B = 8
a = torch.rand(B, L, K, device=device)
b = torch.rand(B, M, K, device=device)
mask = _create_random_sparsity(
torch.ones(B, L, M, dtype=torch.bool, device=device), prob
)
mask_csr = xformers.components.attention.core.SparseCS(mask, device)
row_indices = mask_csr.row_indices
row_offsets = mask_csr.row_offsets
column_indices = mask_csr.column_indices
fn = torch.ops.xformers.coo_sddmm
fn_gt = torch.ops.xformers.sddmm_sputnik
# convert from csr to coo
row_coo, _ = _csr_to_coo(L, M, row_offsets, column_indices)
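    # row_coo holds one row index per non-zero, the COO counterpart of
    # row_offsets; the column indices are unchanged by the conversion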
res = fn(a, b, row_indices, row_coo, column_indices)
res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices)
assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt, atol=1e-6)
@pytest.mark.parametrize("device", _devices)
def test_sddmm_sputnik_backward(device):
contiguous = True
B, L, M, K = 8, 10, 16, 32
prob = 0.5
a = torch.rand(B, L, K, device=device, requires_grad=True)
b = torch.rand(B, M, K, device=device).transpose(-2, -1).requires_grad_(True)
if not contiguous:
a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
mask = _create_random_sparsity(
torch.ones(B, L, M, dtype=torch.bool, device=device), prob
)
mask_csr = xformers.components.attention.core.SparseCS(mask, device)
fn = xformers.components.attention.core._matmul_with_mask
mask = mask.to_sparse()
out_csr = fn(a, b, mask_csr)
out_csr.values.sum().backward()
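    # summing the CSR values equals summing the dense output restricted to the
    # mask, so the gradients below are comparable with the torch.sparse path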
grad_a = a.grad.clone()
grad_b = b.grad.clone()
a.grad = None
b.grad = None
# fn(a[None], b[None], mask).coalesce().values().sum().backward() # TODO check why this fails
fn(a, b, mask).to_dense().sum().backward()
assert torch.allclose(grad_a, a.grad, atol=1e-7)
assert torch.allclose(grad_b, b.grad, atol=1e-7)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_sputnik(device):
B, L = 8, 30
prob = 0.5
a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
a_csr = xformers.components.attention.core.SparseCS(a, device)
fn = xformers.components.attention.core._softmax
a = a.to_sparse()
res = fn(a_csr)
res_gt = fn(a)
res = res.to_dense()
res_gt = res_gt.to_dense()
assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_sputnik_backward(device):
B, L = 8, 30
prob = 0.5
a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
a_csr = xformers.components.attention.core.SparseCS(a, device)
fn = xformers.components.attention.core._softmax
a = a.to_sparse()
a_csr.values.requires_grad_(True)
fn(a_csr).values.sum().backward()
grad_a = a_csr.values.grad.clone()
a.requires_grad_(True)
fn(a).coalesce().values().sum().backward()
assert torch.allclose(
grad_a, a.grad.coalesce().values().reshape_as(grad_a), atol=1e-7
)
@pytest.mark.parametrize("device", _devices)
def test_spmm_sputnik(device):
B, L, K = 8, 30, 32
prob = 0.5
a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
b = torch.rand(B, L, K, device=device)
a_csr = xformers.components.attention.core.SparseCS(a, device)
fn = xformers.components.attention.core.bmm
a = a.to_sparse()
res = fn(a_csr, b)
res_gt = fn(a, b)
assert res.dtype == res_gt.dtype
assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("device", _devices)
def test_spmm_sputnik_backward(device):
B, M, L, K = 8, 16, 30, 32
prob = 0.5
a = _create_random_sparsity(torch.rand(B, M, L, device=device), prob)
b = torch.rand(B, L, K, device=device)
b.requires_grad_(True)
a_csr = xformers.components.attention.core.SparseCS(a, device)
fn = xformers.components.attention.core.bmm
a = a.to_sparse()
a.requires_grad_(True)
a_csr.values.requires_grad_(True)
fn(a_csr, b).sum().backward()
grad_a = a_csr.values.grad.clone()
grad_b = b.grad.clone()
b.grad = None
fn(a, b).sum().backward()
assert torch.allclose(
grad_a, a.grad.coalesce().values().reshape_as(grad_a), atol=1e-7
)
assert torch.allclose(grad_b, b.grad, atol=1e-7)
@cuda_only
def test_csr_transpose():
B, L, K = 8, 30, 40
prob = 0.5
device = torch.device("cuda")
a = _create_random_sparsity(torch.rand(B, L, K, device=device), prob)
a_csr = xformers.components.attention.core.SparseCS(a, device)
res = a_csr.transpose()
res2 = res.transpose()
assert torch.allclose(res.to_dense(), a.transpose(-2, -1))
assert torch.allclose(res2.to_dense(), a)
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("prob", [0.95, 0.996]) # cover > 0.995
@pytest.mark.parametrize("N", [32, 64, 96]) # cover > 64
def test_sparse_bmm(device, contiguous, prob, N):
B, M = 8, 64
a = torch.rand(B, M, N, device=device)
a[a < prob] = 0
a = a.to_sparse()
b = torch.rand(B, N, M, device=device)
if not contiguous:
a = a + a
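        # adding a sparse COO tensor to itself produces an uncoalesced tensor
        # (duplicated indices), which exercises the non-contiguous code path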
b = b.transpose(-2, -1).contiguous().transpose(-2, -1)
res = _sparse_bmm(a, b)
res_gt = _baseline_sparse_bmm(a, b)
assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_sparse_bmm_backward(device, contiguous):
if device == "cuda":
# Skip test for now due to bug in torch 1.8
# See https://github.com/pytorch/pytorch/issues/54975
# Broken CUDA / torch 1.8 combination, awaiting an update
return
B, L, K = 8, 10, 16
prob = 0.5
a = torch.rand(B, L, K, device=device)
a[a < prob] = 0
a = a.to_sparse()
b = torch.rand(B, K, L, device=device, requires_grad=True)
if not contiguous:
a = a + a
b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
a.requires_grad_(True)
def compute_grads(f):
out = f(a, b)
out.sum().backward()
compute_grads(_sparse_bmm)
grad_a = a.grad.clone().coalesce()
grad_b = b.grad.clone()
a.grad = None
b.grad = None
compute_grads(_baseline_sparse_bmm)
new_grad_a = a.grad.coalesce()
assert torch.allclose(grad_a.indices(), new_grad_a.indices())
assert torch.allclose(grad_a.values(), new_grad_a.values())
assert torch.allclose(grad_b, b.grad)
@pytest.mark.parametrize("device", _devices)
def test_sparse_coo_broadcast(device):
B, L, K = 8, 10, 16
prob = 0.5
a = torch.rand(L, K, device=device)
a[a < prob] = 0
a_sparse = a.to_sparse()
res = _broadcast_batch(a_sparse, B)
res_gt = a[None, :, :].expand(B, L, K)
assert torch.allclose(res.to_dense(), res_gt)