Spaces:
Runtime error
Runtime error
File size: 14,389 Bytes
e202b16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 |
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
# needed to register custom ops
import xformers # noqa: F401
import xformers.components.attention.core
from xformers.components.attention._sputnik_sparse import _csr_to_coo
from xformers.components.attention.core import (
_broadcast_batch,
_create_random_sparsity,
_sparse_bmm,
)
# CUDA tests need both a visible device and a torch build that reports a
# CUDA version (``torch.version.cuda`` is falsy otherwise).
_has_cuda = bool(torch.cuda.is_available() and torch.version.cuda)

# Marker for tests that can only run on a CUDA device.
cuda_only = pytest.mark.skipif(not _has_cuda, reason="requires CUDA")

# Devices that device-parametrized tests iterate over.
_devices = ["cpu", "cuda"] if _has_cuda else ["cpu"]
def _baseline_matmul_with_sparse_mask(
a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
assert a.ndim == b.ndim
assert mask.ndim == a.ndim
assert a.shape[-1] == b.shape[-2]
assert a.shape[-2] == mask.shape[-2], f"{a.shape}, {mask.shape}"
assert b.shape[-1] == mask.shape[-1], f"{b.shape}, {mask.shape}"
assert a.shape[:-2] == b.shape[:-2], f"{a.shape}, {b.shape}"
assert a.shape[:-2] == mask.shape[:-2], f"{a.shape}, {mask.shape}"
idxs = mask.indices().unbind()
b = b.transpose(-2, -1)
# compute matmul for elements within the mask
val = (a[idxs[:-2] + (idxs[-2], slice(None))] * b[idxs[:-2] + (idxs[-1], slice(None))]).sum(-1) # type: ignore
out_shape = a.shape[:-1] + (b.shape[-2],)
res = torch.sparse_coo_tensor(torch.stack(idxs), val, out_shape)
return res
def _baseline_matmul_with_dense_mask(
a: torch.Tensor, b: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
res = a @ b
res[~mask] = float("-inf")
return res
def _baseline_sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
# need to use torch.sparse.mm to get gradients wrt sparse matrix a
# TODO implement this in C++ / CUDA as this is slow!
out = []
for ai, bi in zip(a, b):
out.append(torch.sparse.mm(ai, bi))
return torch.stack(out, dim=0)
@pytest.mark.parametrize("is_sparse", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_matmul_with_mask(device, contiguous, is_sparse):
    """``xformers::matmul_with_mask`` matches the pure-torch reference."""
    B, L, K = 8, 30, 32
    prob = 0.5
    lhs = torch.rand(B, L, K, device=device)
    rhs = torch.rand(B, K, L, device=device)
    if not contiguous:
        # Same values, but stored with transposed (non-contiguous) strides.
        lhs = lhs.transpose(-2, -1).contiguous().transpose(-2, -1)
        rhs = rhs.transpose(-2, -1).contiguous().transpose(-2, -1)
    mask = torch.rand(B, L, L, device=device) > prob

    if is_sparse:
        mask = mask.to_sparse()
        reference = _baseline_matmul_with_sparse_mask
    else:
        reference = _baseline_matmul_with_dense_mask

    res = torch.ops.xformers.matmul_with_mask(lhs, rhs, mask)
    res_gt = reference(lhs, rhs, mask)
    if is_sparse:
        res, res_gt = res.to_dense(), res_gt.to_dense()

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("is_sparse", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_matmul_with_mask_backward(device, contiguous, is_sparse):
    """Gradients of ``matmul_with_mask`` w.r.t. ``a`` and ``b`` match the reference."""
    if device == "cuda" and not is_sparse:
        # Broken CUDA / torch 1.8 combination, awaiting an update.
        # See https://github.com/pytorch/pytorch/issues/54975
        # pytest.skip reports the case as skipped instead of silently passing.
        pytest.skip("dense-mask backward is broken on CUDA with torch 1.8")

    B, L, K = 8, 10, 16
    prob = 0.5
    a = torch.rand(B, L, K, device=device, requires_grad=True)
    b = torch.rand(B, K, L, device=device, requires_grad=True)
    if not contiguous:
        # Same values with transposed strides; re-enable grad on the new leaves.
        a = a.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
        b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
    mask = torch.rand(B, L, L, device=device) > prob

    fn = torch.ops.xformers.matmul_with_mask
    fn_gt = _baseline_matmul_with_dense_mask
    if is_sparse:
        mask = mask.to_sparse()
        fn_gt = _baseline_matmul_with_sparse_mask

    def compute_grads(f):
        # Scalar loss = sum of all outputs; sparse outputs are densified first.
        out = f(a, b, mask)
        if is_sparse:
            out = out.to_dense()
        out.sum().backward()

    compute_grads(fn)
    grad_a = a.grad.clone()
    grad_b = b.grad.clone()
    a.grad = None
    b.grad = None

    compute_grads(fn_gt)
    assert torch.allclose(grad_a, a.grad)
    assert torch.allclose(grad_b, b.grad)
@pytest.mark.parametrize("device", _devices)
def test_sddmm_sputnik(device):
    """``_matmul_with_mask`` agrees for a SparseCS mask and a COO mask."""
    B, L, M, K = 8, 30, 16, 32
    prob = 0.5
    a = torch.rand(B, L, K, device=device)
    b = torch.rand(B, M, K, device=device).transpose(-2, -1)
    dense_mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )
    core = xformers.components.attention.core
    mask_csr = core.SparseCS(dense_mask, device)
    mask_coo = dense_mask.to_sparse()

    res = core._matmul_with_mask(a, b, mask_csr).to_dense()
    res_gt = core._matmul_with_mask(a, b, mask_coo).to_dense()

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt)
@cuda_only
@pytest.mark.parametrize("prob", [0.5, 1])
@pytest.mark.parametrize("K", [32, 17])
@pytest.mark.parametrize("M", [30, 17])
@pytest.mark.parametrize("L", [30, 17])
def test_sddmm_csr(L, M, K, prob):
    """The CSR SDDMM kernel agrees with the Sputnik SDDMM kernel."""
    # TODO add more checks for different nnz
    device = torch.device("cuda")
    B = 8
    a = torch.rand(B, L, K, device=device)
    b = torch.rand(B, M, K, device=device)
    mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )
    csr = xformers.components.attention.core.SparseCS(mask, device)
    # Both kernels take the same CSR description of the sparsity pattern.
    sparsity_args = (csr.row_indices, csr.row_offsets, csr.column_indices)

    res = torch.ops.xformers.csr_sddmm(a, b, *sparsity_args)
    res_gt = torch.ops.xformers.sddmm_sputnik(a, b, *sparsity_args)

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt, atol=1e-6)
@cuda_only
@pytest.mark.parametrize("nnz", [0, 4, 16, 20, 36])
def test_sddmm_csr_per_nnz(nnz):
    """CSR and Sputnik SDDMM agree on a 2D mask with a controlled number of non-zeros.

    The mask marks the first ``nnz - 1`` entries in row-major order plus the
    bottom-right corner, i.e. ``nnz`` non-zeros. Because the corner is always
    set, ``nnz=0`` degenerates to a single non-zero.
    """
    device = torch.device("cuda")
    B = 8
    L, M, K = 1024, 1024, 32
    a = torch.rand(B, L, K, device=device)
    b = torch.rand(B, M, K, device=device)

    mask = torch.zeros(L, M, dtype=torch.bool, device=device)
    # Clamp at 0: with nnz=0 the slice ``[: -1]`` would mark every entry but
    # the last, silently turning the smallest case into a fully dense mask.
    mask.view(-1)[: max(nnz - 1, 0)] = True
    mask[-1, -1] = True
    assert int(mask.sum()) == max(nnz, 1)

    mask_csr = xformers.components.attention.core.SparseCS(mask, device)
    row_indices = mask_csr.row_indices
    row_offsets = mask_csr.row_offsets
    column_indices = mask_csr.column_indices

    fn = torch.ops.xformers.csr_sddmm
    fn_gt = torch.ops.xformers.sddmm_sputnik
    res = fn(a, b, row_indices, row_offsets, column_indices)
    res_gt = fn_gt(a, b, row_indices, row_offsets, column_indices)
    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt, atol=1e-6)
@cuda_only
@pytest.mark.parametrize("prob", [0.5, 1])
@pytest.mark.parametrize("K", [32, 17])
@pytest.mark.parametrize("M", [30, 17])
@pytest.mark.parametrize("L", [30, 17])
def test_sddmm_coo(L, M, K, prob):
    """The COO SDDMM kernel agrees with the Sputnik SDDMM kernel."""
    # TODO add more checks for different nnz
    device = torch.device("cuda")
    B = 8
    a = torch.rand(B, L, K, device=device)
    b = torch.rand(B, M, K, device=device)
    mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )
    csr = xformers.components.attention.core.SparseCS(mask, device)
    row_indices = csr.row_indices
    row_offsets = csr.row_offsets
    column_indices = csr.column_indices

    # The COO kernel takes row coordinates converted from the CSR row offsets.
    row_coo, _ = _csr_to_coo(L, M, row_offsets, column_indices)

    res = torch.ops.xformers.coo_sddmm(a, b, row_indices, row_coo, column_indices)
    res_gt = torch.ops.xformers.sddmm_sputnik(
        a, b, row_indices, row_offsets, column_indices
    )

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt, atol=1e-6)
@pytest.mark.parametrize("device", _devices)
def test_sddmm_sputnik_backward(device):
    """Gradients through ``_matmul_with_mask`` agree for SparseCS and COO masks.

    ``a`` is a contiguous tensor and ``b`` is a transposed view of a
    ``(B, M, K)`` tensor; other memory layouts are not exercised here.
    """
    B, L, M, K = 8, 10, 16, 32
    prob = 0.5
    a = torch.rand(B, L, K, device=device, requires_grad=True)
    b = torch.rand(B, M, K, device=device).transpose(-2, -1).requires_grad_(True)
    mask = _create_random_sparsity(
        torch.ones(B, L, M, dtype=torch.bool, device=device), prob
    )
    mask_csr = xformers.components.attention.core.SparseCS(mask, device)
    fn = xformers.components.attention.core._matmul_with_mask
    mask = mask.to_sparse()

    out_csr = fn(a, b, mask_csr)
    out_csr.values.sum().backward()
    grad_a = a.grad.clone()
    grad_b = b.grad.clone()
    a.grad = None
    b.grad = None

    # TODO check why reducing over the coalesced COO values fails:
    # fn(a[None], b[None], mask).coalesce().values().sum().backward()
    fn(a, b, mask).to_dense().sum().backward()
    assert torch.allclose(grad_a, a.grad, atol=1e-7)
    assert torch.allclose(grad_b, b.grad, atol=1e-7)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_sputnik(device):
    """Sparse ``_softmax`` agrees between the SparseCS and COO representations."""
    B, L = 8, 30
    prob = 0.5
    dense = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
    core = xformers.components.attention.core
    a_csr = core.SparseCS(dense, device)
    a_coo = dense.to_sparse()

    res = core._softmax(a_csr).to_dense()
    res_gt = core._softmax(a_coo).to_dense()

    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("device", _devices)
def test_sparse_softmax_sputnik_backward(device):
    """Sparse ``_softmax`` gradients agree between SparseCS and COO inputs."""
    B, L = 8, 30
    prob = 0.5
    dense = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
    core = xformers.components.attention.core
    softmax = core._softmax
    a_csr = core.SparseCS(dense, device)
    a_coo = dense.to_sparse()

    # SparseCS path: the gradient lands on the stored values tensor.
    a_csr.values.requires_grad_(True)
    softmax(a_csr).values.sum().backward()
    grad_csr = a_csr.values.grad.clone()

    # COO path: compare against the coalesced values of the sparse gradient.
    a_coo.requires_grad_(True)
    softmax(a_coo).coalesce().values().sum().backward()
    grad_coo = a_coo.grad.coalesce().values().reshape_as(grad_csr)

    assert torch.allclose(grad_csr, grad_coo, atol=1e-7)
@pytest.mark.parametrize("device", _devices)
def test_spmm_sputnik(device):
    """``core.bmm`` gives the same result for a SparseCS and a COO left operand."""
    B, L, K = 8, 30, 32
    prob = 0.5
    a = _create_random_sparsity(torch.rand(B, L, L, device=device), prob)
    b = torch.rand(B, L, K, device=device)

    a_csr = xformers.components.attention.core.SparseCS(a, device)
    fn = xformers.components.attention.core.bmm
    a = a.to_sparse()

    # The results are compared directly, without densifying.
    res = fn(a_csr, b)
    res_gt = fn(a, b)
    assert res.dtype == res_gt.dtype
    assert torch.allclose(res, res_gt)
@pytest.mark.parametrize("device", _devices)
def test_spmm_sputnik_backward(device):
    """``core.bmm`` gradients agree for SparseCS and COO left operands."""
    B, M, L, K = 8, 16, 30, 32
    prob = 0.5
    dense = _create_random_sparsity(torch.rand(B, M, L, device=device), prob)
    b = torch.rand(B, L, K, device=device).requires_grad_(True)

    core = xformers.components.attention.core
    a_csr = core.SparseCS(dense, device)
    a_coo = dense.to_sparse()
    a_coo.requires_grad_(True)
    a_csr.values.requires_grad_(True)

    core.bmm(a_csr, b).sum().backward()
    grad_a = a_csr.values.grad.clone()
    grad_b = b.grad.clone()
    b.grad = None

    core.bmm(a_coo, b).sum().backward()
    grad_a_coo = a_coo.grad.coalesce().values().reshape_as(grad_a)
    assert torch.allclose(grad_a, grad_a_coo, atol=1e-7)
    assert torch.allclose(grad_b, b.grad, atol=1e-7)
@cuda_only
def test_csr_transpose():
    """Transposing a SparseCS once matches the dense transpose; twice round-trips."""
    B, L, K = 8, 30, 40
    prob = 0.5
    device = torch.device("cuda")
    dense = _create_random_sparsity(torch.rand(B, L, K, device=device), prob)
    sparse = xformers.components.attention.core.SparseCS(dense, device)

    once = sparse.transpose()
    twice = once.transpose()

    assert torch.allclose(once.to_dense(), dense.transpose(-2, -1))
    assert torch.allclose(twice.to_dense(), dense)
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
@pytest.mark.parametrize("prob", [0.95, 0.996])  # cover > 0.995
@pytest.mark.parametrize("N", [32, 64, 96])  # cover > 64
def test_sparse_bmm(device, contiguous, prob, N):
    """``_sparse_bmm`` matches the per-batch ``torch.sparse.mm`` reference."""
    B, M = 8, 64
    dense_a = torch.rand(B, M, N, device=device)
    dense_a[dense_a < prob] = 0
    a = dense_a.to_sparse()
    b = torch.rand(B, N, M, device=device)
    if not contiguous:
        # NOTE(review): presumably meant to feed a non-canonical (uncoalesced)
        # sparse ``a`` — confirm.
        a = a + a
        b = b.transpose(-2, -1).contiguous().transpose(-2, -1)

    assert torch.allclose(_sparse_bmm(a, b), _baseline_sparse_bmm(a, b))
@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("device", _devices)
def test_sparse_bmm_backward(device, contiguous):
    """Gradients of ``_sparse_bmm`` w.r.t. sparse ``a`` and dense ``b`` match the reference."""
    if device == "cuda":
        # Broken CUDA / torch 1.8 combination, awaiting an update.
        # See https://github.com/pytorch/pytorch/issues/54975
        # pytest.skip reports the case as skipped instead of silently passing.
        pytest.skip("sparse bmm backward is broken on CUDA with torch 1.8")

    B, L, K = 8, 10, 16
    prob = 0.5
    a = torch.rand(B, L, K, device=device)
    a[a < prob] = 0
    a = a.to_sparse()
    b = torch.rand(B, K, L, device=device, requires_grad=True)
    if not contiguous:
        a = a + a
        b = b.detach().transpose(-2, -1).contiguous().transpose(-2, -1).requires_grad_()
    a.requires_grad_(True)

    def compute_grads(f):
        # Scalar loss = sum of all outputs.
        out = f(a, b)
        out.sum().backward()

    compute_grads(_sparse_bmm)
    grad_a = a.grad.clone().coalesce()
    grad_b = b.grad.clone()
    a.grad = None
    b.grad = None

    compute_grads(_baseline_sparse_bmm)
    new_grad_a = a.grad.coalesce()
    assert torch.allclose(grad_a.indices(), new_grad_a.indices())
    assert torch.allclose(grad_a.values(), new_grad_a.values())
    assert torch.allclose(grad_b, b.grad)
@pytest.mark.parametrize("device", _devices)
def test_sparse_coo_broadcast(device):
    """``_broadcast_batch`` repeats a 2D sparse tensor along a new leading batch dim."""
    B, L, K = 8, 10, 16
    prob = 0.5
    dense = torch.rand(L, K, device=device)
    dense[dense < prob] = 0

    broadcast = _broadcast_batch(dense.to_sparse(), B)
    expected = dense.unsqueeze(0).expand(B, L, K)

    assert torch.allclose(broadcast.to_dense(), expected)
|