# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools

import pytest
import torch

import xformers
from xformers.components import MultiHeadDispatch
from xformers.components.attention import build_attention
from xformers.components.attention.attention_patterns import block_sparsify_tensor


def catch_oor(fn):
    # Skip the test instead of failing it when the Triton kernel runs out of
    # shared memory / registers on the current GPU.
    @functools.wraps(fn)
    def fn_and_catch_oor(*args, **kwargs):
        from triton import OutOfResources

        try:
            return fn(*args, **kwargs)
        except OutOfResources as e:
            pytest.skip(str(e))
            return None

    return fn_and_catch_oor


# CREDITS:
# Tests taken from, and very lightly changed:
# https://github.com/openai/triton/blob/master/python/test/unit/operators/test_blocksparse.py
# Initially copied here following a fork of the matmul kernel

_triton_available = xformers._is_triton_available()
_matmul_types = []

if _triton_available:
    try:
        import triton  # noqa: F401
        from triton.ops.blocksparse import matmul as blocksparse_matmul
        from triton.ops.blocksparse import softmax as blocksparse_softmax

        from xformers.components.attention import BlockSparseAttention

        _matmul_types = ["sdd", "dsd", "dds"]
    except (ImportError, ModuleNotFoundError) as e:
        import logging

        logging.warning(f"Triton is not available: {e}. Some tests will be skipped")
        _triton_available = False


def mask_tensor(x, mask, block, value=0):
    # Fill every (block x block) tile of x whose layout entry is 0 with `value`
    ret = x.clone()
    for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
        ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value
    return ret
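

# NOTE: an assumed parametrization sketch for the test below. test_matmul expects
# MODE / TRANS_A / TRANS_B / BLOCK / DTYPE to be supplied by pytest, plus a guard
# so it only runs when Triton and a CUDA GPU are available; the concrete grids
# here are illustrative, not authoritative.
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA GPU")
@pytest.mark.parametrize("MODE", _matmul_types)
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])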
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=32, H=2, M=512, N=384, K=256):
    # set seed
    torch.random.manual_seed(0)

    # create inputs
    a = torch.randn(
        (Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
    )
    b = torch.randn(
        (Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
    )
    shape = {
        "sdd": (M, N),
        "dsd": (a.shape[2], a.shape[3]),
        "dds": (b.shape[2], b.shape[3]),
    }[MODE]
    layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))

    # triton result
    op = blocksparse_matmul(
        layout,
        BLOCK,
        MODE,
        trans_a=TRANS_A,
        trans_b=TRANS_B,
        device=torch.device("cuda"),
    )
    ra = block_sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    rb = block_sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
    rc = op(ra, rb)

    # torch result
    ta = mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
    tb = mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
    ta = ta.transpose(2, 3) if TRANS_A else ta
    tb = tb.transpose(2, 3) if TRANS_B else tb
    tc = torch.matmul(ta, tb)
    tc = mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
    tc = block_sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc

    # compare
    torch.testing.assert_close(rc, tc)
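

# NOTE: assumed parametrization for the softmax test (illustrative block sizes,
# widths and dtypes), with the same Triton/CUDA guard as above.
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA GPU")
@pytest.mark.parametrize("BLOCK", [32, 128])
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 2048])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])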
def test_softmax(BLOCK, WIDTH, DTYPE):
    # set seed
    torch.random.manual_seed(0)
    Z, H, M, N = 2, 4, WIDTH, WIDTH
    scale = 0.4

    # create inputs
    layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
    x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")

    # triton result
    op = blocksparse_softmax(layout, BLOCK, device=torch.device("cuda"))
    tx = block_sparsify_tensor(x, layout, BLOCK)
    ty = op(tx, scale=scale)

    # torch result
    rx = mask_tensor(x, layout, BLOCK, value=float("-inf"))
    rx = rx[:, :, : (M // BLOCK) * BLOCK, : (M // BLOCK) * BLOCK]
    ry = torch.softmax(rx * scale, -1)
    ry = block_sparsify_tensor(ry, layout, BLOCK)

    # compare
    torch.testing.assert_close(ry, ty)


# 16, 32,
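# NOTE: assumed decorators for the fwd/bwd test: Triton/CUDA guard, block and
# dtype grids (illustrative values; 43 exercises the unsupported-size branch),
# and catch_oor so out-of-resource GPUs skip instead of failing.
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA GPU")
@pytest.mark.parametrize("block", [32, 43, 128])
@pytest.mark.parametrize("dtype", [torch.float16])
@catch_oor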
def test_attention_fwd_bwd(
    block,
    dtype,
    input_scale=1.0,
    scale=1 / 8.0,
    n_ctx=384,
    batch_size=2,
    n_heads=2,
):
    # inputs
    head_dim = 64
    qkv_shape = (batch_size, n_heads, n_ctx, head_dim)
    qkvs = [
        torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
        .to(dtype)
        .cuda()
        for _ in range(3)
    ]

    def loss_fn(x):
        return (x**2).mean()

    # Triton:
    n_blocks = n_ctx // block
    layout = torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)
    query, key, value = [x.clone() for x in qkvs]
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()

    if block not in [16, 32, 64, 128]:
        # Check that unsupported dimensions are caught
        with pytest.raises(AssertionError):
            _ = BlockSparseAttention(layout, block)
    else:
        block_sparse_attention = BlockSparseAttention(layout, block)
        attn_out = block_sparse_attention(q=query, k=key, v=value, scale=scale)

        # ad hoc loss
        loss = loss_fn(attn_out)
        loss.backward()
        grads = [query.grad, key.grad, value.grad]

        # Torch version:
        torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
        torch_q = torch_q * scale
        torch_q.retain_grad()
        torch_k.retain_grad()
        torch_v.retain_grad()
        scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
        probs = torch.softmax(scores, dim=-1)
        torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)

        # ad hoc loss
        torch_loss = loss_fn(torch_attn_out)
        torch_loss.backward()
        torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]

        # comparison
        torch.testing.assert_close(
            loss, torch_loss, msg=f"Triton loss {loss} and torch loss {torch_loss}"
        )
        for g1, g2 in zip(grads, torch_grads):
            torch.testing.assert_close(
                torch.norm(g1),
                torch.norm(g2),
                msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
            )
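

# NOTE: assumed guard and dtype grid for the parity test below (illustrative).
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA GPU")
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])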
def test_blocksparse_attention_parity(dtype):
    def _reset_seeds():
        torch.manual_seed(0)

    seq = 64
    model = 128
    heads = 4
    block_size = 16
    batch_size = 2
    batched_dim = heads * batch_size
    dim_head = model // heads

    test_config = {
        "dropout": 0.0,
        "causal": False,
        "seq_len": seq,
        "num_heads": 4,
        "dim_head": dim_head,
        "block_size": block_size,
        "layout": torch.ones(seq // block_size, seq // block_size, dtype=torch.long),
    }

    inputs = torch.rand(batched_dim, seq, model, device="cuda", dtype=dtype)

    _reset_seeds()
    test_config["name"] = "scaled_dot_product"
    attention_sdp = build_attention(test_config)
    multi_head_sdp = MultiHeadDispatch(
        seq_len=seq,
        dim_model=model,
        residual_dropout=0.0,
        num_heads=heads,
        attention=attention_sdp,
    ).to(device=torch.device("cuda"), dtype=dtype)
    r_sdp = multi_head_sdp(inputs, inputs, inputs)

    _reset_seeds()
    test_config["name"] = "blocksparse"
    attention_blocksparse = build_attention(test_config)
    multi_head_blocksparse = MultiHeadDispatch(
        seq_len=seq,
        dim_model=model,
        residual_dropout=0.0,
        num_heads=heads,
        attention=attention_blocksparse,
    ).to(device=torch.device("cuda"), dtype=dtype)
    r_blocksparse = multi_head_blocksparse(inputs, inputs, inputs)

    torch.testing.assert_close(r_sdp, r_blocksparse, atol=5e-5, rtol=6e-3)