Flexstorydiff / xformers /tests /test_triton_blocksparse.py
FlexTheAi's picture
Upload folder using huggingface_hub
e202b16 verified
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import pytest
import torch
import xformers
from xformers.components import MultiHeadDispatch
from xformers.components.attention import build_attention
from xformers.components.attention.attention_patterns import block_sparsify_tensor
def catch_oor(fn):
@functools.wraps(fn)
def fn_and_catch_oor(*args, **kwargs):
from triton import OutOfResources
try:
return fn(*args, **kwargs)
except OutOfResources as e:
pytest.skip(str(e))
return None
return fn_and_catch_oor
# CREDITS:
# Tests from, very lightly changed
# https://github.com/openai/triton/blob/master/python/test/unit/operators/test_blocksparse.py
# Initially copied here folowing a fork from the matmul kernel
_triton_available = xformers._is_triton_available()
_matmul_types = []
if _triton_available:
try:
import triton # noqa: F401
from triton.ops.blocksparse import matmul as blocksparse_matmul
from triton.ops.blocksparse import softmax as blocksparse_softmax
from xformers.components.attention import BlockSparseAttention
_matmul_types = ["sdd", "dsd", "dds"]
except (ImportError, ModuleNotFoundError) as e:
import logging
logging.warning(f"Triton is not available: {e}. Some tests will be skipped")
_triton_available = False
def mask_tensor(x, mask, block, value=0):
ret = x.clone()
for h, i, j in zip(*(mask == 0).nonzero(as_tuple=True)):
ret[:, h, i * block : (i + 1) * block, j * block : (j + 1) * block] = value
return ret
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
@pytest.mark.parametrize("MODE", _matmul_types)
@pytest.mark.parametrize("TRANS_A", [False, True])
@pytest.mark.parametrize("TRANS_B", [False, True])
@pytest.mark.parametrize("BLOCK", [16, 32, 64])
@pytest.mark.parametrize("DTYPE", [torch.float16])
@catch_oor
def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=32, H=2, M=512, N=384, K=256):
# set seed
torch.random.manual_seed(0)
# create inputs
a = torch.randn(
(Z, H, K, M) if TRANS_A else (Z, H, M, K), dtype=DTYPE, device="cuda"
)
b = torch.randn(
(Z, H, N, K) if TRANS_B else (Z, H, K, N), dtype=DTYPE, device="cuda"
)
shape = {
"sdd": (M, N),
"dsd": (a.shape[2], a.shape[3]),
"dds": (b.shape[2], b.shape[3]),
}[MODE]
layout = torch.randint(2, (H, shape[0] // BLOCK, shape[1] // BLOCK))
# triton result
op = blocksparse_matmul(
layout,
BLOCK,
MODE,
trans_a=TRANS_A,
trans_b=TRANS_B,
device=torch.device("cuda"),
)
ra = block_sparsify_tensor(a, layout, BLOCK) if MODE == "dsd" else a
rb = block_sparsify_tensor(b, layout, BLOCK) if MODE == "dds" else b
rc = op(ra, rb)
# torch result
ta = mask_tensor(a, layout, BLOCK) if MODE == "dsd" else a
tb = mask_tensor(b, layout, BLOCK) if MODE == "dds" else b
ta = ta.transpose(2, 3) if TRANS_A else ta
tb = tb.transpose(2, 3) if TRANS_B else tb
tc = torch.matmul(ta, tb)
tc = mask_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
tc = block_sparsify_tensor(tc, layout, BLOCK) if MODE == "sdd" else tc
# compare
torch.testing.assert_close(rc, tc)
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
@pytest.mark.parametrize("BLOCK", [32, 128])
@pytest.mark.parametrize("WIDTH", [256, 576, 1024, 1792])
@pytest.mark.parametrize("DTYPE", [torch.float16, torch.float32])
def test_softmax(BLOCK, WIDTH, DTYPE):
# set seed
torch.random.manual_seed(0)
Z, H, M, N = 2, 4, WIDTH, WIDTH
scale = 0.4
# create inputs
layout = torch.randint(2, (H, M // BLOCK, N // BLOCK))
x = torch.randn((Z, H, M, N), dtype=DTYPE, requires_grad=True, device="cuda")
# triton result
op = blocksparse_softmax(layout, BLOCK, device=torch.device("cuda"))
tx = block_sparsify_tensor(x, layout, BLOCK)
ty = op(tx, scale=scale)
# torch result
rx = mask_tensor(x, layout, BLOCK, value=float("-inf"))
rx = rx[:, :, : (M // BLOCK) * BLOCK, : (M // BLOCK) * BLOCK]
ry = torch.softmax(rx * scale, -1)
ry = block_sparsify_tensor(ry, layout, BLOCK)
# compare
torch.testing.assert_close(ry, ty)
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
@pytest.mark.parametrize("block", [32, 43, 128]) # 16, 32,
@pytest.mark.parametrize("dtype", [torch.float16])
def test_attention_fwd_bwd(
block,
dtype,
input_scale=1.0,
scale=1 / 8.0,
n_ctx=384,
batch_size=2,
n_heads=2,
):
# inputs
head_dim = 64
qkv_shape = (batch_size, n_heads, n_ctx, head_dim)
qkvs = [
torch.nn.Parameter(input_scale * torch.randn(qkv_shape), requires_grad=True)
.to(dtype)
.cuda()
for _ in range(3)
]
def loss_fn(x):
return (x**2).mean()
# Triton:
n_blocks = n_ctx // block
layout = torch.ones([n_heads, n_blocks, n_blocks], dtype=torch.long)
query, key, value = [x.clone() for x in qkvs]
query.retain_grad()
key.retain_grad()
value.retain_grad()
if block not in [16, 32, 64, 128]:
# Check that unsupported dimensions are caught
with pytest.raises(AssertionError):
_ = BlockSparseAttention(layout, block)
else:
block_sparse_attention = BlockSparseAttention(layout, block)
attn_out = block_sparse_attention(q=query, k=key, v=value, scale=scale)
# ad hoc loss
loss = loss_fn(attn_out)
loss.backward()
grads = [query.grad, key.grad, value.grad]
# Torch version:
torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
torch_q = torch_q * scale
torch_q.retain_grad()
torch_k.retain_grad()
torch_v.retain_grad()
scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
probs = torch.softmax(scores, dim=-1)
torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
# ad hoc loss
torch_loss = loss_fn(torch_attn_out)
torch_loss.backward()
torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
# comparison
torch.testing.assert_close(
loss, torch_loss, msg=f"Triton loss {loss} and torch loss {torch_loss}"
)
for g1, g2 in zip(grads, torch_grads):
torch.testing.assert_close(
torch.norm(g1),
torch.norm(g2),
msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
)
@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
@pytest.mark.parametrize("dtype", [torch.float16])
def test_blocksparse_attention_parity(dtype):
def _reset_seeds():
torch.manual_seed(0)
seq = 64
model = 128
heads = 4
block_size = 16
batch_size = 2
batched_dim = heads * batch_size
dim_head = model // heads
test_config = {
"dropout": 0.0,
"causal": False,
"seq_len": seq,
"num_heads": 4,
"dim_head": dim_head,
"block_size": block_size,
"layout": torch.ones(seq // block_size, seq // block_size, dtype=torch.long),
}
inputs = torch.rand(batched_dim, seq, model, device="cuda", dtype=dtype)
_reset_seeds()
test_config["name"] = "scaled_dot_product"
attention_sdp = build_attention(test_config)
multi_head_sdp = MultiHeadDispatch(
seq_len=seq,
dim_model=model,
residual_dropout=0.0,
num_heads=heads,
attention=attention_sdp,
).to(device=torch.device("cuda"), dtype=dtype)
r_sdp = multi_head_sdp(inputs, inputs, inputs)
_reset_seeds()
test_config["name"] = "blocksparse"
attention_blocksparse = build_attention(test_config)
multi_head_blocksparse = MultiHeadDispatch(
seq_len=seq,
dim_model=model,
residual_dropout=0.0,
num_heads=heads,
attention=attention_blocksparse,
).to(device=torch.device("cuda"), dtype=dtype)
r_blocksparse = multi_head_blocksparse(inputs, inputs, inputs)
torch.testing.assert_close(r_sdp, r_blocksparse, atol=5e-5, rtol=6e-3)