# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
import random
from functools import partial
from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar, Union
import pytest
import torch
import torch.nn.functional as F
from scipy.stats import binomtest
from torch.utils.checkpoint import checkpoint
import xformers.ops
from xformers.attn_bias_utils import create_attn_bias
from xformers.ops import fmha
from xformers.ops.fmha import ALL_BW_OPS, ALL_FW_OPS
from xformers.ops.fmha.common import AttentionFwOpBase, AttentionOpBase
from xformers.ops.fmha.dispatch import _dispatch_fw_priority_list
from .utils import (
assert_allclose,
cuda_only,
disable_on_rocm,
disable_tf32,
pack_kv_cache,
ref_attention_bmhk_for_test,
ref_attention_for_test,
rocm_only,
)
compute_capability = (0, 0)
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability("cuda")
sm70_or_better_only = pytest.mark.skipif(
compute_capability < (7, 0), reason="requires sm70+"
)
sm75_or_better_only = pytest.mark.skipif(
compute_capability < (7, 5), reason="requires sm75+"
)
sm80_or_better_only = pytest.mark.skipif(
compute_capability < (8, 0), reason="requires sm80+"
)
skip_if_rocm = pytest.mark.skipif(
torch.version.hip is not None, reason="not supported on ROCm"
)
skip_if_pt_cutlass = pytest.mark.skipif(
fmha.cutlass.USE_TORCH_CUTLASS, reason="using PT cutlass"
)
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
T = TypeVar(
"T", Type[fmha.common.AttentionFwOpBase], Type[fmha.common.AttentionBwOpBase]
)
logger = logging.getLogger("xformers")
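# Keep only the operators that can actually run in this environment: the op must be
# available in this build, and must either support CPU or have a CUDA minimum compute
# capability no higher than the current device's.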
def _filter_unsupported_ops(ops: Sequence[T]) -> List[T]:
return [
op
for op in ops
if (
"cpu" in op.SUPPORTED_DEVICES
or op.CUDA_MINIMUM_COMPUTE_CAPABILITY <= compute_capability
)
and op.is_available()
]
ALL_FW_OPS_NO_UNPADDED_LSE = _filter_unsupported_ops(ALL_FW_OPS)
ALL_FW_OPS = _filter_unsupported_ops(
ALL_FW_OPS
+ (
[fmha.flash.FlashFwUnpaddedLSE]
if fmha.flash.FLASH_SUPPORTS_UNPADDED_LSE
else []
)
)
ALL_BW_OPS = _filter_unsupported_ops(ALL_BW_OPS)
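# Pick a seeded-random forward operator that supports the given inputs
# (used e.g. to pair a forward op with the backward op under test).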
def sample_random_supported_fw(
inp: fmha.Inputs, seed: int
) -> Type[fmha.common.AttentionFwOpBase]:
r = random.Random(seed)
fw_ops = list(ALL_FW_OPS)
r.shuffle(fw_ops)
for op in fw_ops:
if op.supports(inp):
return op
raise NotImplementedError(f"Could not find a FW operator for: {inp}")
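# Build a list of (B, Mq, Mkv, H, K, Kv) test shapes for an op: the op's declared test
# batch sizes and head dims, some deliberately odd sizes, different K/Kv when the op
# supports it, several head counts, and (for a few ops) 200 extra random shapes.
# Shapes the op reports as unsupported are filtered out.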
def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
shapes = []
for B in op._TEST_BATCH_SIZES:
for Mq in [32, 256]:
for Mkv in [32, 64, 256, 1024]:
for K in op._TEST_K:
shapes.append((B, Mq, Mkv, 1, K, K))
Mq = 256
Mkv = 128
K = 32
H = 1
# Weird values of parameters
for M in [2, 3, 15, 31, 32, 34, 68, 72, 90, 132, 136]:
shapes.append((B, M, Mkv, H, K, K))
shapes.append((B, Mq, M, H, K, K))
for _K in [1, 2, 3, 31, 34, 36, 38, 40, 64, 80, 160, 256 + 2, 256 + 8, 512]:
if _K <= op.SUPPORTED_MAX_K:
shapes.append((B, Mq, Mkv, H, _K, _K))
# Different value for K / Kv
if op.SUPPORTS_DIFFERENT_VALUE_EMBED:
for _K in [32, 36, 64, 256 + 8]:
shapes.append((B, Mq, Mkv, H, K, _K))
shapes.append((B, Mq, Mkv, H, _K, K))
# Exotic sizes
for K in op._TEST_K:
shapes.append((B, 16, 1024, H, K, K))
shapes.append((B, 1024, 16, H, K, K))
# Some number of heads
for H in [3, 5, 12]:
shapes.append((max(1, B // H), Mq, Mkv, H, K, K))
    # Filter out shapes that the op does not support
shapes = [
shape
for shape in shapes
if len(
op.shape_not_supported_reasons(
Mq=shape[1], Mkv=shape[2], K=shape[4], Kv=shape[5]
)
)
== 0
]
# Add some random shapes
if op in [
fmha.cutlass.FwOp,
fmha.cutlass.BwOp,
fmha.flash.BwOp,
fmha.ck.FwOp,
]:
K_CHOICES = [8 * i for i in range(1, 256 // 8)]
r = random.Random(0)
found_count = 0
while found_count < 200:
B = r.randint(1, 400)
Mq = r.randint(1, 500)
Mkv = r.randint(1, 500)
H = r.randint(2, 11)
B = max(B // H, 1)
K = r.choice(K_CHOICES)
Kv = r.choice(K_CHOICES)
if not op.SUPPORTS_DIFFERENT_VALUE_EMBED:
Kv = K
if len(op.shape_not_supported_reasons(Mq, Mkv, K, Kv)):
continue
found_count += 1
shapes.append((B, Mq, Mkv, H, K, Kv))
return shapes
def make_id(op, device, dtype, bias_type, *shape):
return (
f"{op.NAME}-{device}-{str(dtype)}-{bias_type.__name__}"
f"-{'-'.join([str(s) for s in shape])}"
)
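# Build the pytest parametrization (argvalues + ids) for a list of ops: for every
# generated shape, iterate the supported devices and dtypes, pick one random supported
# bias type per case, shrink/reorder the shape for memory-hungry or bottom-right biases,
# cap the number of shapes per op, and finally add a few large-stride shapes that run
# without any mask.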
def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(
ops_list: Sequence[Type[fmha.AttentionOpBase]], max_shapes_per_op: int = 65000
):
r = random.Random(0)
combination = []
ids = []
for op in ops_list:
op_count = 0
# Sort list of masks, so it's deterministic across runs
LIST_MASKS = list(sorted(op.SUPPORTED_ATTN_BIAS_TYPES, key=lambda x: str(x)))
for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op):
has_one = False
for device in _devices:
if device not in op.SUPPORTED_DEVICES:
continue
for dtype in op.SUPPORTED_DTYPES:
bias_type = r.choice(LIST_MASKS)
# Avoid using too much memory
if bias_type not in [
type(None),
fmha.attn_bias.LowerTriangularMask,
]:
B, Mq, Mkv, H, K, Kv = shape
B = min(B, 12)
if bias_type in {
fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
}:
Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) + 2
elif bias_type in {
fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
}:
Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq)
shape = (B, Mq, Mkv, H, K, Kv)
combination.append((op, device, dtype, bias_type, *shape))
ids.append(
f"{op.NAME}-{device}-{str(dtype)}-{bias_type.__name__}"
f"-{'-'.join([str(s) for s in shape])}"
)
has_one = True
if has_one:
op_count += 1
if op_count > max_shapes_per_op:
break
# Some specific shapes for which we want to run without any mask
bias_type = type(None)
for shape in (
            # Some strides/dims don't fit in a uint16
(1, 128, 128, 300, 128, 128),
(13, 1, 67, 200, 8, 8),
(1, 1 + 2**16, 4, 1, 8, 8),
(1, 4, 1 + 2**16, 1, 8, 8),
            # TODO: Some strides don't fit in a uint32
# Crashes on Flash, Errors on Cutlass
# (1, 1, 64000, 300, 128, 128)
):
for device in _devices:
if device not in op.SUPPORTED_DEVICES:
continue
for dtype in op.SUPPORTED_DTYPES:
combination.append((op, device, dtype, bias_type, *shape))
return {
"argvalues": combination,
"ids": [make_id(*c) for c in combination],
}
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = pytest.mark.parametrize(
"opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS),
)
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_NO_UNPADDED_LSE = (
pytest.mark.parametrize(
"opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS_NO_UNPADDED_LSE),
)
)
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = pytest.mark.parametrize(
"opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS, max_shapes_per_op=1),
)
parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = pytest.mark.parametrize(
"opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS),
)
parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = pytest.mark.parametrize(
"opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS, max_shapes_per_op=1),
)
def _rand_partition(r: random.Random, total: int, n: int) -> List[int]:
# returns list of n nonnegative integers summing to total
idx = {0, total}
while len(idx) < n + 1:
idx.add(r.randint(1, total - 1))
s = sorted(idx)
return [e - b for b, e in zip(s[:-1], s[1:])]
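# Return the gradient accumulated on a tensor attn_bias (optionally clearing it),
# or None when the bias is not a plain tensor.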
def get_bias_grad(attn_bias, clear: bool = False) -> Optional[torch.Tensor]:
tensor_with_grad: Optional[torch.Tensor] = None
if isinstance(attn_bias, torch.Tensor):
tensor_with_grad = attn_bias
if tensor_with_grad is not None:
grad = tensor_with_grad.grad
if clear:
tensor_with_grad.grad = None
return grad
return None
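# Create query/key/value tensors in the requested layout ("BMK", "BMHK" or "BMGHK")
# together with the attention bias. Sequence lengths are adjusted for bottom-right /
# local masks (q_len <= kv_len) and rounded up to a page multiple for paged biases;
# the test is skipped if the op cannot handle the resulting inputs.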
def create_tensors(
op: Optional[Type[AttentionOpBase]],
device,
dtype,
attn_bias_type,
B,
q_len,
kv_len,
h,
k,
kv,
*,
attn_bias_requires_grad: bool = False,
fmt: str = "BMK",
g: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
torch.manual_seed(B * q_len + kv_len * k + kv)
mask_is_bottom_right = attn_bias_type is not None and issubclass(
attn_bias_type,
(
fmha.attn_bias.LowerTriangularFromBottomRightMask,
fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask,
fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
fmha.attn_bias.LocalAttentionFromBottomRightMask,
),
)
if mask_is_bottom_right and q_len > kv_len:
# Bottom-right attention and local-attention masks require q_len <= kv_len
kv_len = q_len
if attn_bias_type is not None and issubclass(
attn_bias_type,
(
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
),
):
page_size_choices = [256, 512]
if op is not None and issubclass(op, fmha.triton_splitk.FwOp):
# TODO: enable small pages for flash attention when that's implemented
page_size_choices.extend([64, 128])
page_size = random.choice(page_size_choices)
kv_len_paged = (kv_len + page_size - 1) // page_size * page_size
else:
kv_len_paged = kv_len
page_size = None
scale = 3
if fmt == "BMK":
query = torch.randn((B * h, q_len, k), device=device, dtype=dtype)
key = torch.randn((B * h, kv_len_paged, k), device=device, dtype=dtype)
value = torch.randn((B * h, kv_len_paged, kv), device=device, dtype=dtype)
elif fmt == "BMHK":
query = torch.randn((B, q_len, h, k), device=device, dtype=dtype)
key = torch.randn((B, kv_len_paged, h, k), device=device, dtype=dtype)
value = torch.randn((B, kv_len_paged, h, kv), device=device, dtype=dtype)
else:
assert fmt == "BMGHK"
query = torch.randn((B, q_len, g, h, k), device=device, dtype=dtype)
key = torch.randn((B, kv_len_paged, g, 1, k), device=device, dtype=dtype)
value = torch.randn((B, kv_len_paged, g, 1, kv), device=device, dtype=dtype)
for x in [query, key, value]:
x.mul_(scale)
if fmt == "BMGHK":
# Expand - after the in-place mul
key = key.expand((B, kv_len_paged, g, h, k))
value = value.expand((B, kv_len_paged, g, h, k))
if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(attn_bias_type):
attn_bias_type = None
attn_bias = None
if attn_bias_type is not None:
attn_bias = create_attn_bias(
attn_bias_type,
batch_size=B,
num_heads=h,
num_heads_groups=g,
q_len=q_len,
kv_len=kv_len,
dtype=dtype,
device=device,
requires_grad=attn_bias_requires_grad,
fmt=fmt,
op=op,
page_size=page_size,
)
if isinstance(
attn_bias,
(
fmha.attn_bias.BlockDiagonalMask,
fmha.attn_bias.BlockDiagonalGappyKeysMask,
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
),
):
query, key, value = [
x.reshape([1, -1, *x.shape[2:]]) for x in [query, key, value]
]
inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias)
if op is not None:
reasons = op.not_supported_reasons(inputs)
if reasons:
err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})"
# Ensure we free memory to avoid OOMs
del query, key, value, attn_bias, inputs
pytest.skip(err_msg)
return query, key, value, attn_bias
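# Layout helpers: convert between BMHK ([B, M, H, K]) and the flattened
# BMK ([B * H, M, K]) representation.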
def bmhk2bmk(tensor) -> torch.Tensor:
return (
tensor.permute((0, 2, 1, 3))
.contiguous()
.view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]])
)
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
(0, 2, 1, 3)
)
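# Pad the sequence dimension (dim 1) with NaNs up to a multiple of 256, then slice back
# to the original length. The NaNs sit just past the logical end of the buffer, so a
# kernel that reads out of bounds will surface NaNs in its output.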
def nanify_oob_seqlen(x: torch.Tensor) -> torch.Tensor:
align_to = 256
if x.shape[1] % align_to == 0:
return x
pad = [0, 0] * x.ndim
pad[-3] = align_to - (x.shape[1] % align_to)
x_pad = torch.nn.functional.pad(x, pad, value=math.nan)
return x_pad[:, : x.shape[1]]
@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@pytest.mark.parametrize("packed", [False, True])
@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_NO_UNPADDED_LSE
def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs):
(
op,
device,
dtype,
bias_type,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
if packed and issubclass(
bias_type,
(
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
),
):
pytest.skip(
"packed doesn't make sense with paged attention, since q has different shape than k/v"
)
if packed and not (k == kv and q_len == kv_len):
pytest.skip(
f"packed incompatible with `k ({k}) != kv ({kv})` or `q_len ({q_len}) != kv_len ({kv_len})`"
)
if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(bias_type):
pytest.skip("BMK incompatible with this bias")
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK" if packed else fmt,
**kwargs,
)
if packed:
c = torch.stack([query, key, value], 2)
if fmt == "BMK":
# bm3hk -> 3bhmk -> 3Bmk
c = c.permute(2, 0, 3, 1, 4).view([3, -1, q_len, k])
query, key, value = c[0], c[1], c[2]
# Re-create bias in the right format
attn_bias = create_attn_bias(
bias_type=bias_type,
batch_size=batch_size,
num_heads=h,
num_heads_groups=1,
q_len=q_len,
kv_len=kv_len,
device=device,
dtype=dtype,
requires_grad=False,
fmt=fmt,
op=op,
)
elif fmt == "BMHK":
# bm3hk -> 3 x bmhk
query, key, value = xformers.ops.unbind(c, 2)
else:
assert False, f"Unsupport fmt {fmt} with packing"
assert not query.is_contiguous()
out = xformers.ops.memory_efficient_attention_forward(
query, key, value, attn_bias, op=op
)
assert not out.isnan().any(), ("Output has NaNs", attn_bias)
out2 = xformers.ops.memory_efficient_attention_forward(
nanify_oob_seqlen(query),
nanify_oob_seqlen(key),
nanify_oob_seqlen(value),
attn_bias,
op=op,
)
assert not out2.isnan().any(), "Output has NaNs - most likely reading out-of-bounds"
assert torch.allclose(out, out2, atol=0.0, rtol=0.0), (
"Non-deterministic behavior",
attn_bias,
)
ref = ref_attention_for_test(query, key, value, attn_bias)
assert out.shape == ref.shape, out.shape
assert_allclose(
out.float(),
ref,
atol=op.ERROR_ATOL[dtype],
rtol=op.ERROR_RTOL.get(dtype, 1e-5),
)
@cuda_only
@pytest.mark.parametrize("k_len", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("kv_len", [128, 512])
@pytest.mark.parametrize("q_len", [128, 512])
def test_key_query_all_ones(q_len, kv_len, batch_size, k_len):
device = "cuda"
scale = 3
# composable kernel doesn't support fp32
dtype = torch.float16 if torch.version.hip else torch.float32
query = torch.ones((batch_size, q_len, k_len), device=device, dtype=dtype)
key = torch.ones((batch_size, kv_len, k_len), device=device, dtype=dtype)
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
out = xformers.ops.memory_efficient_attention(query, key, value)
# this should be equivalent to the average over value
ref = value.mean(1, keepdim=True).expand_as(query)
assert_allclose(out, ref, atol=1e-5)
def _block_diag_reshape_lse(
lse: torch.Tensor, q_seqinfo: fmha.attn_bias._SeqLenInfo
) -> torch.Tensor:
"""LSE can be padded, let's remove the padding"""
parts = []
for slice, (start, end) in zip(lse.unbind(0), q_seqinfo.intervals()):
parts.append(slice[:, : end - start])
return torch.cat(parts, dim=1).unsqueeze(0)
@disable_tf32
@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
(
op,
device,
dtype,
bias_type,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
if op is fmha.ck.FwOp:
pytest.skip("logsumexp is not yet supported by ck-tiled fmha!")
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
)
_out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
query,
key,
value,
op=op,
attn_bias=attn_bias,
)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
attn = (query.float() / k**0.5) @ key.float().transpose(-2, -1)
if attn_bias is not None:
if isinstance(
attn_bias,
(fmha.attn_bias.AttentionBias, fmha.attn_bias.AttentionBiasSubTensor),
):
bias_shape = (1, 1, query.shape[2], key.shape[2])
tensor_bias = attn_bias.materialize(
bias_shape,
device=query.device,
dtype=torch.float32,
)
else:
assert type(attn_bias) is torch.Tensor
tensor_bias = attn_bias
attn = attn + tensor_bias.float()
ref_lse = attn.logsumexp(-1)
if isinstance(
attn_bias,
(
fmha.attn_bias.BlockDiagonalMask,
fmha.attn_bias.BlockDiagonalGappyKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
),
) and issubclass(op, (fmha.flash.FwOp, fmha.cutlass.FwOp)):
# Sometimes LSE is returned in padded format, i.e. (B, H, MAX_LEN) instead of (H, TOTAL_LEN).
# Unpad to compare with the reference.
# This is the case for Flash Attention when UNPADDED_LSE=False and for CUTLASS.
if op.UNPADDED_LSE:
lse = lse.unsqueeze(0)
else:
lse = _block_diag_reshape_lse(lse, attn_bias.q_seqinfo)
if op is fmha.cutlass.FwOp:
        # CUTLASS kernel pads the last dimension of LSE to 32
lse = lse[:, :, : ref_lse.shape[2]]
assert_allclose(lse, ref_lse, atol=2e-4)
@cuda_only
@pytest.mark.parametrize("op", [fmha.cutlass.FwOp, fmha.flash.FwOp])
def test_logsumexp_mqa(op):
if not op.is_available():
pytest.skip("not available")
if op.CUDA_MINIMUM_COMPUTE_CAPABILITY > compute_capability:
skip_reason = (
f"requires device with capability >= {op.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
f"but your GPU has capability {compute_capability} (too old)"
)
pytest.skip(skip_reason)
dtype = torch.float16
s = 3
query = torch.randn([1, 1, 32, 128], dtype=dtype, device="cuda") * s
key = (torch.randn([1, 16, 1, 128], dtype=dtype, device="cuda") * s).expand(
-1, -1, 32, -1
)
value = (torch.randn([1, 16, 1, 128], dtype=dtype, device="cuda") * s).expand(
-1, -1, 32, -1
)
assert key.stride(2) == 0
_, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
query,
key,
value,
op=op,
)
query, key, value = [x[0].transpose(0, 1) for x in [query, key, value]]
attn = (query.float() / query.shape[-1] ** 0.5) @ key.float().transpose(-2, -1)
ref_lse = attn.logsumexp(-1)
assert_allclose(lse[0, :, 0], ref_lse[:, 0], atol=2e-4)
@disable_tf32
@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@pytest.mark.parametrize("grad_out_contiguous", [False, True])
@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
def test_backward(
opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
grad_out_contiguous,
fmt,
):
(
op_bw,
device,
dtype,
bias_type,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
attn_bias_requires_grad = (
random.Random(q_len + kv_len * batch_size).randint(0, 1) > 0
)
query, key, value, attn_bias = create_tensors(
*opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
attn_bias_requires_grad=attn_bias_requires_grad,
fmt=fmt,
)
# To understand why we do this, check the comment on the
# `AttentionBwOpBase` class
scale = None
if op_bw.SUPPORTS_CUSTOM_SCALE and query.shape[-1] < 32:
scale = (1 / 32) ** 0.5
op_fw = (
sample_random_supported_fw(
fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias),
seed=q_len * kv + kv_len * k,
)
if op_bw != fmha.cutlass.BwOp
else fmha.cutlass.FwOp
)
if op_bw == fmha.ck.BwOp:
op_fw = fmha.ck.FwOp
if dtype == torch.bfloat16:
pytest.skip(
"CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
)
if grad_out_contiguous is False:
pytest.skip("CK Fmha does not support contiguous layout for grad_out!")
if k % 2 != 0:
pytest.skip(
"CK Fmha currently requires the headdim size of query input be an even value!"
)
qkv = None
if (
fmt == "BMHK"
and query.shape[3] == value.shape[3]
and query.shape[1] == value.shape[1]
):
qkv = torch.stack([query, key, value], 2)
qkv.requires_grad_(True)
# bm3hk -> 3 x bmhk
query, key, value = xformers.ops.unbind(qkv, 2)
assert not query.is_contiguous()
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)):
pytest.skip("inputs not supported")
out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, scale=scale, op=(op_fw, op_bw)
)
grad_out = torch.randn_like(out)
if grad_out_contiguous is False:
grad_out = torch.tensor([1.0], dtype=query.dtype, device=device)[
None, None, :
].expand_as(out)
out.backward(grad_out)
if qkv is None and op_bw == fmha.cutlass.BwOp:
assert query.stride() == query.grad.stride()
grads = []
if qkv is None:
grads = [query.grad, key.grad, value.grad]
query.grad = None
key.grad = None
value.grad = None
else:
grads = [qkv.grad]
qkv.grad = None
if attn_bias_requires_grad:
attn_bias_grad = get_bias_grad(attn_bias, clear=True)
if attn_bias_grad is not None:
grads.append(attn_bias_grad)
ref = ref_attention_for_test(query, key, value, attn_bias, scale=scale)
ref.backward(grad_out)
assert_allclose(
out.float(),
ref.float(),
"fw pass",
atol=op_fw.ERROR_ATOL[dtype],
rtol=op_fw.ERROR_RTOL[dtype],
)
del out
del grad_out
del ref
atol = op_bw.ERROR_ATOL[dtype]
rtol = op_bw.ERROR_RTOL[dtype]
grads_ref = []
grads_name = []
if qkv is None:
assert isinstance(query.grad, torch.Tensor)
assert isinstance(key.grad, torch.Tensor)
assert isinstance(value.grad, torch.Tensor)
grads_ref = [query.grad, key.grad, value.grad]
grads_name = ["query", "key", "value"]
else:
assert isinstance(qkv.grad, torch.Tensor)
grads_ref = [qkv.grad]
grads_name = ["qkv"]
if attn_bias_requires_grad:
attn_bias_grad = get_bias_grad(attn_bias)
if attn_bias_grad is not None:
grads_ref.append(attn_bias.grad)
grads_name.append("bias")
del query
del key
del value
del qkv
assert len(grads_ref) == len(
grads
), "Wrong number of gradients (maybe bias grad didn't backprop?)"
for name, calc_grad, ref_grad in zip(grads_name, grads, grads_ref):
assert_allclose(
calc_grad,
ref_grad,
msg=f"{op_fw.NAME}+{op_bw.NAME}:{name}",
atol=atol,
rtol=rtol,
)
def _vec_binom_test(x, n, p):
"""
vectorized implementation of scipy.stats.binom_test
this makes our tests much faster
reference: https://github.com/scipy/scipy/blob/v1.8.0/scipy/stats/_morestats.py#L2609-L2702
"""
import numpy as np
from scipy.stats import distributions
x = np.atleast_1d(x)
d = distributions.binom.pmf(x, n, p)[:, None]
rerr = 1 + 1e-7
# x < p * n case
i = np.arange(np.ceil(p * n), n + 1)
y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1)
pval1 = distributions.binom.cdf(x, n, p) + distributions.binom.sf(n - y, n, p)
# other case
i = np.arange(np.floor(p * n) + 1)
y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1)
pval2 = distributions.binom.cdf(y - 1, n, p) + distributions.binom.sf(x - 1, n, p)
pval = np.where(x < p * n, pval1, pval2)
pval = np.minimum(1.0, pval)
return pval
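# Recompute the dropout keep-mask that the given forward op would generate for the
# current RNG state, using the op-specific RNG kernels exposed for testing
# (_cutlass_rand_uniform / _ck_rand_uniform / _temp_dropout).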
def _get_drop_mask(op, batch_size, q_len, kv_len, p, device):
if op == fmha.cutlass.FwOp:
mask = torch.empty((batch_size, 1, q_len, kv_len), device=device)
rand_uniform = torch.ops.xformers._cutlass_rand_uniform(p, mask)
mask = (rand_uniform > p).to(torch.float32)
mask = mask.reshape(batch_size, q_len, kv_len)
elif op == fmha.ck.FwOp:
mask = torch.empty((batch_size, 1, q_len, kv_len), device=device)
# rand_uniform is an int8_t tensor
rand_uniform = torch.ops.xformers._ck_rand_uniform(p, mask)
mask = (rand_uniform <= int((1.0 - p) * 255.0)).to(torch.float32)
mask = mask.reshape(batch_size, q_len, kv_len)
else:
mask = torch.empty((batch_size, q_len, kv_len), device=device)
mask = torch.ops.xformers._temp_dropout(mask, p)
return mask
@cuda_only
@pytest.mark.parametrize("attn_bias", [None, fmha.attn_bias.LowerTriangularMask()])
@pytest.mark.parametrize("seed", [42, 124])
@pytest.mark.parametrize("p", [0.3, 0.7])
@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33, 65])
@pytest.mark.parametrize("q_len", [2, 33])
@pytest.mark.parametrize(
"op",
ALL_FW_OPS_NO_UNPADDED_LSE,
ids=list(map(lambda t: t.NAME, ALL_FW_OPS_NO_UNPADDED_LSE)),
)
def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed, attn_bias):
device = "cuda"
scale = 3
dtype = torch.float
if torch.version.hip and op == fmha.ck.FwOp:
dtype = torch.float16
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
inputs_for_support_check = fmha.Inputs(query, key, value, attn_bias, p, None)
if not op.supports(inputs_for_support_check):
del query, key, value, attn_bias
pytest.skip(f"{op.NAME}: unsupported input")
torch.manual_seed(seed)
out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, p, op=(op, None)
)
torch.manual_seed(seed)
out2 = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, p, op=(op, None)
)
assert_allclose(out, out2, "dropout reproducibility")
torch.manual_seed(seed)
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
ref = ref_attention_for_test(query, key, value, attn_bias, mask, p)
if dtype is torch.float:
assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"
else:
assert_allclose(out.float(), ref, atol=2.8e-2), f"{(out - ref).abs().max()}"
num_trials = 1000
p_val_tol = 1e-6
keep_prob = 1 - p
masks = []
for i in range(num_trials):
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
masks.append(mask.clone().cpu())
masks = torch.stack(masks, dim=0)
p_value = binomtest(int(masks.sum()), masks.numel(), p=keep_prob).pvalue
assert p_value > p_val_tol, p_value
masks = masks.sum(0).flatten()
p_values = _vec_binom_test(masks, num_trials, p=keep_prob)
assert all(p_values > p_val_tol)
def _test_dropout_backward(q_len, kv_len, batch_size, k, p, op, dtype):
if dtype is torch.bfloat16 and compute_capability < (8, 0):
pytest.skip("bf16 requires Sm80")
if not op.is_available():
pytest.skip()
scale = 3
device = "cuda"
query = torch.randn((batch_size, q_len, k), device=device, dtype=dtype) * scale
key = torch.randn((batch_size, kv_len, k), device=device, dtype=dtype) * scale
value = torch.randn((batch_size, kv_len, k), device=device, dtype=dtype) * scale
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
grad_out = torch.ones_like(query)
assert op.supports(fmha.Inputs(query=query, key=key, value=value, p=p))
seed = 42
torch.manual_seed(seed)
out = xformers.ops.memory_efficient_attention(query, key, value, p=p, op=(op, None))
out.backward(grad_out)
grad_q = query.grad
grad_k = key.grad
grad_v = value.grad
query.grad = None
key.grad = None
value.grad = None
torch.manual_seed(seed)
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device)
ref = ref_attention_for_test(query, key, value, None, mask, p)
ref.backward(grad_out)
atol, rtol = (
fmha.AttentionBwOpBase.ERROR_ATOL[dtype],
fmha.AttentionBwOpBase.ERROR_RTOL[dtype],
)
assert_allclose(
grad_v,
value.grad,
"grad_v",
atol=atol,
rtol=rtol,
)
# TODO: Investigate why precision is worse
if dtype in [torch.float16, torch.bfloat16]:
atol = atol * 2 + 0.15
rtol = rtol * 2
assert_allclose(
grad_q,
query.grad,
"grad_q",
atol=atol,
rtol=rtol,
)
assert_allclose(
grad_k,
key.grad,
"grad_k",
atol=atol,
rtol=rtol,
)
@cuda_only
@disable_tf32
@pytest.mark.parametrize("p", [0.3, 0.7])
@pytest.mark.parametrize("k", [5, 6, 32])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 15, 32, 33])
@pytest.mark.parametrize("q_len", [2, 33])
def test_dropout_backward_small_k(q_len, kv_len, batch_size, k, p):
_test_dropout_backward(
q_len, kv_len, batch_size, k, p, op=fmha.small_k.FwOp, dtype=torch.float32
)
@cuda_only
@disable_tf32
@pytest.mark.parametrize("p", [0.000001, 0.3, 0.7])
@pytest.mark.parametrize("k", [16, 128, 256])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 248, 256])
@pytest.mark.parametrize("q_len", [3, 248, 256])
@pytest.mark.parametrize("dt", ["f16", "bf16", "f32"])
def test_dropout_backward_cutlass(dt, q_len, kv_len, batch_size, k, p):
_test_dropout_backward(
q_len,
kv_len,
batch_size,
k,
p,
op=fmha.cutlass.FwOp,
dtype={"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dt],
)
@cuda_only
@pytest.mark.parametrize("p", [0.000001, 0.3, 0.7])
@pytest.mark.parametrize("k", [16, 64, 128])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("kv_len", [3, 248, 256])
@pytest.mark.parametrize("q_len", [3, 248, 256])
@pytest.mark.parametrize("dt", ["f16"])
def test_dropout_backward_ck(dt, q_len, kv_len, batch_size, k, p):
_test_dropout_backward(
q_len,
kv_len,
batch_size,
k,
p,
op=fmha.ck.FwOp,
dtype={"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dt],
)
@cuda_only
@disable_tf32
@disable_on_rocm
@pytest.mark.parametrize("k_len", [32])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("kv_len", [3 * 32])
@pytest.mark.parametrize("q_len", [3 * 32])
def test_memory_efficient_attention_full_block_masked(q_len, kv_len, batch_size, k_len):
device = "cuda"
op_fw = fmha.small_k.FwOp
op_bw = fmha.small_k.BwOp
scale = 3
query = torch.randn((batch_size, q_len, k_len), device=device) * scale
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale
# in this case, most of the blocks in a row get masked
attn_bias = torch.full((3, 32), float("-inf"), device=device)
attn_bias[:2, :4] = 0
attn_bias = attn_bias.flatten()[None, None, :].expand(1, q_len, -1)
out = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias, op=(op_fw, op_bw)
)
ref = ref_attention_for_test(query, key, value, attn_bias)
assert_allclose(
out, ref, atol=op_fw.ERROR_ATOL[query.dtype], rtol=op_fw.ERROR_RTOL[query.dtype]
)
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
grad_out = torch.ones_like(query)
out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias)
out.backward(grad_out)
grad_q = query.grad
grad_k = key.grad
grad_v = value.grad
query.grad = None
key.grad = None
value.grad = None
ref = ref_attention_for_test(query, key, value, attn_bias)
ref.backward(grad_out)
atol = op_bw.ERROR_ATOL[query.dtype]
rtol = op_bw.ERROR_RTOL[query.dtype]
assert_allclose(grad_q, query.grad, "grad_q", atol=atol, rtol=rtol)
assert_allclose(grad_k, key.grad, "grad_k", atol=atol, rtol=rtol)
assert_allclose(grad_v, value.grad, "grad_v", atol=atol, rtol=rtol)
@pytest.mark.parametrize("fmt", ["BMK", "BMHK"])
@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_lowlevel_api_shapes(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt):
query, key, value, attn_bias = create_tensors(
*opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt=fmt
)
grad_out = torch.ones_like(query)
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad(
query, key, value, attn_bias
)
assert out.ndim == query.ndim
dq, dk, dv = xformers.ops.memory_efficient_attention_backward(
grad_out, out, lse, query, key, value, attn_bias
)
assert dq.shape == query.shape
assert dk.shape == key.shape
assert dv.shape == value.shape
@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_cuda_streams(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
(
op,
device,
dtype,
bias_type,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
if device != "cuda":
pytest.skip("Not CUDA")
bias_type = None
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = [
op,
device,
dtype,
bias_type,
batch_size,
q_len,
kv_len,
h,
k,
kv,
]
s_hipri = torch.cuda.Stream(priority=-1)
s_lopri = torch.cuda.Stream(priority=0)
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMHK"
)
torch.cuda.synchronize()
with torch.cuda.stream(s_lopri):
torch.cuda._sleep(100_000_000) # wait 100m cycles
query *= 2
s_hipri.wait_stream(s_lopri)
with torch.cuda.stream(s_hipri):
# If the kernel is scheduled in the main stream
# `query * 2` has not been executed yet
out = xformers.ops.memory_efficient_attention(query, key, value, op=(op, None))
# Test that `s_lopri` is still sleeping
# and that `query *= 2` has not been executed yet
query2_main_stream = query * 2
torch.cuda.synchronize()
# TODO: Figure out why this is failing sometimes
# The sleep timer seems to be high enough already ...
# assert torch.allclose(query2_main_stream, query), "Need to increase sleep time"
del query2_main_stream
ref = ref_attention_for_test(query, key, value)
assert out.shape == ref.shape, out.shape
assert_allclose(
out.float(),
ref.float(),
atol=op.ERROR_ATOL[dtype],
rtol=op.ERROR_RTOL.get(dtype, 1e-5),
)
@disable_tf32
@parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv):
p = 0.0
scale = 0.1
(
op_bw,
device,
dtype,
_,
B,
q_len,
kv_len,
H,
k,
Kv,
) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
torch.manual_seed(q_len + kv_len + k)
if device != "cuda":
pytest.skip("Not CUDA")
query, key, value, attn_bias = create_tensors(
*opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMK"
)
inputs = fmha.Inputs(
query=query, key=key, value=value, attn_bias=attn_bias, scale=scale
)
op_fw = sample_random_supported_fw(inputs, seed=q_len * k + kv_len * k)
grad_out = query.new_ones(B * H, q_len, Kv)
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
reasons = op_fw.not_supported_reasons(inputs)
if reasons:
pytest.skip(f"{op_fw.NAME}: unsupported ({'/'.join(reasons)})")
reasons = op_bw.not_supported_reasons(inputs)
if reasons:
pytest.skip(f"{op_bw.NAME}: unsupported ({'/'.join(reasons)})")
    # NOTE: we still need to scale the inputs to avoid blowing up
# the pre-softmax values (numerical stability)
s = k**-0.5
out = xformers.ops.memory_efficient_attention(
query * s, key, value, attn_bias, p, scale, op=(op_fw, op_bw)
)
out.backward(grad_out)
grad_q, grad_k, grad_v = query.grad, key.grad, value.grad
query.grad = key.grad = value.grad = None
ref = ref_attention_for_test(query * s, key, value, attn_bias, None, p, scale)
ref.backward(grad_out)
ref_grad_q, ref_grad_k, ref_grad_v = query.grad, key.grad, value.grad
query.grad = key.grad = value.grad = None
atol = op_fw.ERROR_ATOL[dtype]
rtol = op_fw.ERROR_RTOL[dtype]
assert_allclose(out.float(), ref.float(), "out", atol=atol, rtol=rtol)
atol = op_bw.ERROR_ATOL[dtype]
rtol = op_bw.ERROR_RTOL[dtype]
assert_allclose(grad_q, ref_grad_q, "grad_q", atol=atol, rtol=rtol)
assert_allclose(grad_k, ref_grad_k, "grad_k", atol=atol, rtol=rtol)
assert_allclose(grad_v, ref_grad_v, "grad_v", atol=atol, rtol=rtol)
def apply_attention(query, key, value, attn_bias, op_fw, proj):
x = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attn_bias, op=(op_fw, None)
)
x = proj(x)
return x
@pytest.mark.parametrize("use_reentrant", [False, True])
@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_grad_checkpointing(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
use_reentrant,
):
fmt = "BMHK"
(
op,
device,
dtype,
bias_type,
batch_size,
q_len,
kv_len,
h,
k,
kv,
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv
if op is fmha.triton_splitk.FwOp:
pytest.skip("Triton Flash Decoding doesn't support backward pass yet")
if op is fmha.ck.FwOp:
pytest.skip("ck-tiled FMHA doesn't supported backward pass yet")
bias_type = None
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = (
op,
device,
dtype,
bias_type,
batch_size,
q_len,
kv_len,
h,
k,
kv,
)
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt=fmt,
)
qkv = None
if (
fmt == "BMHK"
and query.shape[3] == value.shape[3]
and query.shape[1] == value.shape[1]
):
qkv = torch.stack([query, key, value], 2)
qkv.requires_grad_(True)
# bm3hk -> 3 x bmhk
query, key, value = xformers.ops.unbind(qkv, 2)
assert not query.is_contiguous()
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
proj = torch.nn.Linear(kv, k, device=device, dtype=dtype)
x = query
for _ in range(5):
x = checkpoint(
apply_attention,
x,
key,
value,
attn_bias,
op,
proj,
use_reentrant=use_reentrant,
)
x.mean().backward()
ALL_FW_OPS_NO_SMALLK = [op for op in ALL_FW_OPS if op is not fmha.small_k.FwOp]
@pytest.mark.parametrize(
"op", ALL_FW_OPS_NO_SMALLK, ids=[op.NAME for op in ALL_FW_OPS_NO_SMALLK]
)
def test_unsupported_cpu(op: Type[fmha.AttentionFwOpBase]):
q = torch.empty([1, 1, 1, 32])
with pytest.raises(ValueError):
fmha.memory_efficient_attention(q, q, q, op=(op, None))
@cuda_only
@pytest.mark.parametrize(
"op", ALL_FW_OPS_NO_SMALLK, ids=[op.NAME for op in ALL_FW_OPS_NO_SMALLK]
)
def test_unsupported_stride_lastdim(op: Type[fmha.AttentionFwOpBase]):
q = torch.empty([1, 1, 32, 4], device="cuda", dtype=torch.float16).permute(
0, 3, 1, 2
)
try:
fmha.memory_efficient_attention(q, q, q, op=(op, None))
except ValueError as e:
if "Only work on pre-MLIR triton for now" in str(e):
pytest.skip("Only work on pre-MLIR triton for now")
q = q.contiguous()
fmha.memory_efficient_attention(q, q, q, op=(op, None))
@cuda_only
@pytest.mark.parametrize(
"op", ALL_FW_OPS_NO_SMALLK, ids=[op.NAME for op in ALL_FW_OPS_NO_SMALLK]
)
def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]):
q = torch.empty([1, 2, 1, 33], device="cuda", dtype=torch.float16)[:, :, :, :32]
try:
fmha.memory_efficient_attention(q, q, q, op=(op, None))
except ValueError as e:
if "Only work on pre-MLIR triton for now" in str(e):
pytest.skip("Only work on pre-MLIR triton for now")
q = q.contiguous()
fmha.memory_efficient_attention(q, q, q, op=(op, None))
@sm75_or_better_only
def test_unsupported_dropout_combine_flash_cutlass() -> None:
q = torch.empty(
[1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True
)
with pytest.raises(ValueError):
out = fmha.memory_efficient_attention(
q, q, q, p=0.1, op=(fmha.cutlass.FwOp, fmha.flash.BwOp)
)
out.backward(out)
with pytest.raises(ValueError):
out = fmha.memory_efficient_attention(
q, q, q, p=0.1, op=(fmha.flash.FwOp, fmha.cutlass.BwOp)
)
out.backward(out)
def test_attn_bias_causal() -> None:
m = -math.inf
causal_mask = torch.tensor([[0, m], [0, 0], [0, 0]])
tensor_bias = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
attn_bias = fmha.attn_bias.LowerTriangularMask()
assert_allclose(attn_bias.materialize(causal_mask.shape), causal_mask, "causal")
attn_bias = attn_bias.add_bias(tensor_bias)
assert_allclose(
attn_bias.materialize(causal_mask.shape),
tensor_bias + causal_mask,
"causal+tensor_bias",
)
def test_attn_bias_torch_tensor() -> None:
tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias)
m = -math.inf
causal_bias = torch.tensor([[0, m, m], [0, 0, m]])
assert_allclose(
attn_bias.materialize((2, 3)), causal_bias + tensor_bias, "tensor_bias+causal"
)
def test_attn_bias_blockdiag() -> None:
queries = [
torch.randn([1, 3, 1, 8]),
torch.randn([1, 2, 1, 8]),
torch.randn([1, 5, 1, 8]),
]
attn_bias, q = fmha.BlockDiagonalMask.from_tensor_list(queries)
# Verify mask
as_tensor = attn_bias.materialize((10, 10))
assert int((as_tensor != -math.inf).sum().item()) == 3 * 3 + 2 * 2 + 5 * 5
assert_allclose(as_tensor[0:3, 0:3], torch.zeros([3, 3]), "batch0")
assert_allclose(as_tensor[3:5, 3:5], torch.zeros([2, 2]), "batch1")
assert_allclose(as_tensor[5:, 5:], torch.zeros([5, 5]), "batch2")
# Verify we can split it back
queries2 = attn_bias.split(q)
assert len(queries) == len(queries2)
for q1, q2 in zip(queries, queries2):
assert_allclose(q1, q2)
def test_attn_bias_blockdiag_batched() -> None:
queries = [
torch.randn([1, 3, 1, 8]),
torch.randn([3, 2, 1, 8]),
torch.randn([1, 5, 1, 8]),
]
attn_bias, q = fmha.BlockDiagonalMask.from_tensor_list(queries)
# Verify mask
as_tensor = attn_bias.materialize((14, 14))
assert int((as_tensor != -math.inf).sum().item()) == 3 * 3 + 3 * 2 * 2 + 5 * 5
assert_allclose(as_tensor[0:3, 0:3], torch.zeros([3, 3]), "batch0")
assert_allclose(as_tensor[3:5, 3:5], torch.zeros([2, 2]), "batch1.0")
assert_allclose(as_tensor[5:7, 5:7], torch.zeros([2, 2]), "batch1.1")
assert_allclose(as_tensor[7:9, 7:9], torch.zeros([2, 2]), "batch1.2")
assert_allclose(as_tensor[9:, 9:], torch.zeros([5, 5]), "batch2")
# Verify we can split it back
queries2 = attn_bias.split(q)
assert len(queries) == len(queries2)
for q1, q2 in zip(queries, queries2):
assert_allclose(q1, q2)
def test_attn_bias_blockdiag_crossattn_causal() -> None:
# Q / KV have different seqlen
list_q = [
torch.randn([1, 3, 1, 8]),
torch.randn([2, 1, 1, 8]),
]
list_k = [
torch.randn([1, 2, 1, 8]),
torch.randn([2, 3, 1, 8]),
]
attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv(
list_q, list_k
)
# Verify mask
as_tensor = attn_bias.materialize((q.shape[1], k.shape[1]))
assert int((as_tensor != -math.inf).sum().item()) == 3 * 2 + 2 * 3 * 1
assert_allclose(as_tensor[0:3, 0:2], torch.zeros([3, 2]), "batch0")
assert_allclose(as_tensor[3:4, 2:5], torch.zeros([1, 3]), "batch1.0")
assert_allclose(as_tensor[4:, 5:], torch.zeros([1, 3]), "batch1.1")
# Also test causal version
as_tensor = attn_bias.make_causal().materialize((q.shape[1], k.shape[1]))
assert_allclose(
as_tensor[3:4, 2:5],
fmha.attn_bias.LowerTriangularMask().materialize((1, 3)),
"batch1.0[causal]",
)
# Verify we can split it back
list_q2 = attn_bias.split_queries(q)
assert len(list_q) == len(list_q2)
for q1, q2 in zip(list_q, list_q2):
assert_allclose(q1, q2)
with pytest.raises(ValueError):
attn_bias.split_queries(k)
list_k2 = attn_bias.split_kv(k)
assert len(list_k) == len(list_k2)
for k1, k2 in zip(list_k, list_k2):
assert_allclose(k1, k2)
def test_attn_bias_blockdiag_crossattn_causal_with_prefix_qk_cond() -> None:
list_q = [
torch.randn([1, 3, 1, 8]),
]
list_k = [
torch.randn([1, 2, 1, 8]),
]
attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv(
list_q, list_k
)
with pytest.raises(ValueError):
attn_bias.make_causal_from_bottomright()
def test_attn_bias_blockdiag_crossattn_causal_with_prefix() -> None:
# Q / KV have different seqlen
list_q = [
torch.randn([1, 2, 1, 8]),
torch.randn([2, 2, 1, 8]),
]
list_k = [
torch.randn([1, 2, 1, 8]),
torch.randn([2, 5, 1, 8]),
]
attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv(
list_q, list_k
)
as_tensor = attn_bias.make_causal_from_bottomright().materialize(
(q.shape[1], k.shape[1])
)
m = -math.inf
assert_allclose(
as_tensor[0:2, 0:2],
torch.tensor([[0, m], [0, 0]], dtype=torch.float32),
"batch1.1[causal_with_prefix]",
)
assert_allclose(
as_tensor[2:4, 2:7],
torch.tensor([[0, 0, 0, 0, m], [0, 0, 0, 0, 0]], dtype=torch.float32),
"batch2.1[causal_with_prefix]",
)
assert_allclose(
as_tensor[4:6, 7:12],
torch.tensor([[0, 0, 0, 0, m], [0, 0, 0, 0, 0]], dtype=torch.float32),
"batch2.2[causal_with_prefix]",
)
@cuda_only
def test_attn_bias_padded() -> None:
bsize, n_heads, d, padding = 8, 3, 8, 32
torch.manual_seed(0)
# Q / KV have different seqlen
k = torch.randn((bsize, padding, n_heads, d), device="cuda", dtype=torch.float16)
k_seqlen = [5, 8, 7, 1, 9, 3, 12, 32]
other = bsize - 1
v = torch.randn((bsize, padding, n_heads, d), device="cuda", dtype=torch.float16)
n_q_first = 4
q = [
torch.randn((1, n_q_first, n_heads, d), device="cuda", dtype=torch.float16),
torch.randn((1, other, n_heads, d), device="cuda", dtype=torch.float16),
]
q_cat = torch.cat([x.view(1, -1, n_heads, d) for x in q], dim=1)
q_seqlen = [n_q_first] + [1] * other
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=q_seqlen,
kv_seqlen=k_seqlen,
kv_padding=padding,
)
v = v.view(1, -1, n_heads, d)
k = k.view(1, -1, n_heads, d)
scores = (q_cat.transpose(1, 2) @ k.transpose(1, 2).transpose(2, 3)).float()
assert not scores.isnan().any()
mask = torch.full_like(scores, -float("inf"))
for i, (slen, qlen) in enumerate(zip(k_seqlen, q_seqlen)):
kseq_start = i * padding
qstart = sum(q_seqlen[:i])
mask[:, :, qstart : qstart + qlen, kseq_start : kseq_start + slen] = torch.triu(
mask[:, :, qstart : qstart + qlen, kseq_start : kseq_start + slen].float(),
diagonal=1 + slen - qlen,
).float()
scores += mask
assert not scores.isnan().any()
    # scores above: (1, 3, 11, 8) @ (1, 3, 8, 256) -> (1, 3, 11, 256)
scores = torch.nn.functional.softmax(scores, -1).half()
    # (1, 3, 11, 256) @ (1, 3, 256, 8) -> (1, 3, 11, 8)
    output = scores @ v.transpose(1, 2)
output = output.transpose(1, 2).contiguous()
fmha_output = fmha.memory_efficient_attention_forward(
q_cat, k, v, attn_bias, scale=1.0
)
# assert torch.allclose(output, fmha_output)
assert_allclose(
output,
fmha_output,
atol=fmha.cutlass.FwOp.ERROR_ATOL[torch.float16],
rtol=fmha.cutlass.FwOp.ERROR_RTOL[torch.float16],
)
@cuda_only
def test_attn_bias_to_copy() -> None:
def _test_to_copy(attn_bias: torch.Tensor) -> None:
assert attn_bias.device.type == "cpu", f"{attn_bias.device}"
attn_bias_cuda = attn_bias.cuda()
assert attn_bias_cuda.device.type == "cuda", f"{attn_bias_cuda.device}"
attn_bias_fp16 = attn_bias.to(torch.float16)
assert attn_bias_fp16.device.type == "cpu", f"{attn_bias_fp16.device}"
assert attn_bias_fp16.dtype == torch.float16, f"{attn_bias_fp16.dtype}"
attn_bias = fmha.attn_bias.LowerTriangularMask().to("cpu")
_test_to_copy(attn_bias)
tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu")
_test_to_copy(attn_bias)
def _kv_heads_label(kv_heads: Optional[int]) -> str:
if kv_heads is None:
return ""
if kv_heads == 1:
return "mq"
return f"gqa{kv_heads}"
@sm70_or_better_only
@pytest.mark.parametrize(
"op",
[
fmha.decoder.FwOp if torch.version.cuda else fmha.ck_decoder.FwOp,
],
)
@pytest.mark.parametrize("kv_heads", [None, 1, 2], ids=_kv_heads_label)
@pytest.mark.parametrize("bsz,n_heads", [(1, 1), (1, 16), (1, 32), (8, 1), (4, 8)])
@pytest.mark.parametrize("padding", [32, 4096])
@pytest.mark.parametrize("dtype", ["f16", "bf16", "f32"])
def test_decoder(
op,
n_heads: int,
kv_heads: Optional[int],
padding: int,
bsz: int,
dtype: str,
dequant: bool = False,
num_queries: int = 1,
d: int = 128,
) -> None:
# kv_heads = 1: multiquery
# kv_heads = None: neither MQA nor GQA
# kv_heads > 1: BMGHK
if dtype == "bf16" and compute_capability < (8, 0):
raise pytest.skip("BF16 is only supported on SM80+")
import triton
if dequant and triton.__version__[:4] < "3.0.":
raise pytest.skip("dequant needs triton updates")
dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype]
torch.manual_seed(1)
if kv_heads is not None and kv_heads > 1:
k_shape: Tuple[int, ...] = (1, bsz * padding, kv_heads, n_heads, d)
q_shape: Tuple[int, ...] = (
1,
bsz * num_queries,
kv_heads,
n_heads,
d,
)
else:
k_shape = (1, bsz * padding, n_heads, d)
q_shape = (1, bsz * num_queries, n_heads, d)
# TODO: support 2 kv heads etc.
k = torch.randn(k_shape, dtype=dtype_, device="cuda")
k_seqlen = torch.randint(num_queries, padding + 1, (bsz,)).tolist()
v = torch.randn(k_shape, dtype=dtype_, device="cuda")
q = torch.randn(q_shape, dtype=dtype_, device="cuda")
if dequant:
k_shape = k_shape[:-1] + (d // 8 + op.NUM_GROUPS,)
k = torch.zeros(k_shape, dtype=torch.int32, device="cuda")
k.random_()
k[..., : op.NUM_GROUPS].view(torch.float16).fill_(1.0)
v = torch.zeros(k_shape, dtype=torch.int32, device="cuda")
v.random_()
v[..., : op.NUM_GROUPS].view(torch.float16).fill_(1.0)
if kv_heads is not None:
k = k[..., :1, :].expand(k_shape)
v = v[..., :1, :].expand(k_shape)
if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, k, v)):
pytest.skip("; ".join(skip_reasons))
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[num_queries] * bsz,
kv_seqlen=k_seqlen,
kv_padding=padding,
)
decoder_output = fmha.memory_efficient_attention_forward(
q,
k,
v,
attn_bias,
op=op,
)
def dequant_cache(x):
x = x[..., op.NUM_GROUPS :, None].expand(k_shape[:-1] + (d // 8, 8))
x = x // (2 ** (4 * torch.arange(8, device="cuda")))
x = (x % 16).flatten(start_dim=-2)
return x.to(dtype_) + 1.0
if dequant:
k = dequant_cache(k)
v = dequant_cache(v)
ref_output = ref_attention_for_test(q, k, v, attn_bias)
assert_allclose(
decoder_output.to(ref_output.dtype),
ref_output,
atol=op.ERROR_ATOL[dtype_] * 4,
rtol=op.ERROR_RTOL[dtype_],
)
@sm80_or_better_only
@pytest.mark.parametrize(
"op,dequant,dtype",
[
(fmha.triton_splitk.FwOp_S1, False, "bf16"),
(fmha.triton_splitk.FwOp_S2, False, "f16"),
(fmha.triton_splitk.FwOp_S2, True, "bf16"),
(
type(
"S2_8", (fmha.triton_splitk.FwOp_S2,), {"NUM_GROUPS": 8, "NAME": "S2_8"}
),
True,
"bf16",
),
],
)
@pytest.mark.parametrize("kv_heads", [None, 1, 2], ids=_kv_heads_label)
@pytest.mark.parametrize("n_heads", [16])
@pytest.mark.parametrize("padding, bsz", [(32, 8), (4096, 1)])
def test_triton_splitk_decoder(
op,
dequant: bool,
kv_heads: Optional[int],
n_heads: int,
padding: int,
bsz: int,
dtype: str,
) -> None:
# We omit dequant with f16: it needs a very high tol
test_decoder(
op,
kv_heads=kv_heads,
n_heads=n_heads,
padding=padding,
bsz=bsz,
dtype=dtype,
dequant=dequant,
)
@rocm_only
@pytest.mark.parametrize(
"op", [fmha.ck_splitk.FwOp_S1, fmha.ck_splitk.FwOp_S2, fmha.ck_splitk.FwOp_S4]
)
@pytest.mark.parametrize("dtype", ["f32"])
@pytest.mark.parametrize("kv_heads", [None, 1, 2], ids=_kv_heads_label)
@pytest.mark.parametrize("n_heads", [16])
@pytest.mark.parametrize("d", [128, 256])
@pytest.mark.parametrize("padding, bsz", [(32, 8), (4096, 1), (32, 1), (4096, 8)])
def test_ck_splitk_decoder(
op,
kv_heads: Optional[int],
n_heads: int,
padding: int,
bsz: int,
dtype: str,
d: int,
) -> None:
# no quantized impl compared to cuda
test_decoder(
op,
kv_heads=kv_heads,
n_heads=n_heads,
padding=padding,
bsz=bsz,
dtype=dtype,
d=d,
)
@sm80_or_better_only
@pytest.mark.parametrize(
"op",
[
fmha.triton_splitk.FwOp_S1,
fmha.triton_splitk.FwOp_S2,
],
ids=lambda op: f"splitk{op.SPLIT_K}",
)
@pytest.mark.parametrize("multiquery", [True, False], ids=lambda x: "mq" if x else "")
# n_heads=1 => it's ambiguous whether it counts as multiquery
@pytest.mark.parametrize("padding, bsz", [(32, 8), (44, 1)])
@pytest.mark.parametrize("dtype", ["f16", "bf16"])
@pytest.mark.parametrize("n_heads, num_queries", [(2, 4), (2, 5), (6, 7), (20, 3)])
def test_triton_splitk_decoder_manyqueries(
op,
multiquery: bool,
n_heads: int,
padding: int,
bsz: int,
dtype: str,
num_queries: int,
) -> None:
kv_heads = 1 if multiquery else None
test_decoder(
op,
kv_heads=kv_heads,
n_heads=n_heads,
padding=padding,
bsz=bsz,
dtype=dtype,
num_queries=num_queries,
dequant=False,
)
def test_attn_bias_from_seqlens() -> None:
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens([3, 5, 1])
out = bias.split(torch.randn([1, 3 + 5 + 1, 16]))
assert len(out) == 3
assert tuple(out[0].shape) == (1, 3, 16)
@cuda_only
def test_attn_bias_blockdiag_doc() -> None:
"""IMPORTANT:
This is the example in the doc for `BlockDiagonalMask`.
If this example needs to be updated, please also update the doc
"""
import torch
from xformers.ops import fmha
if torch.version.hip:
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")
K = 16
dtype = torch.float16
device = "cuda"
list_x = [
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
]
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) # type: ignore
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
list_out = attn_bias.split(out)
assert tuple(list_out[0].shape) == (1, 3, 1, K)
@cuda_only
class TestAttnBias:
@staticmethod
def create_tensors(
dtype,
B: int = 2,
Mq: int = 32,
Mkv: int = 32,
H: int = 3,
K: int = 16,
Kv: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return (
torch.randn([B, Mq, H, K], device="cuda", dtype=dtype) * 3,
torch.randn([B, Mkv, H, K], device="cuda", dtype=dtype) * 3,
torch.randn([B, Mkv, H, Kv], device="cuda", dtype=dtype) * 3,
torch.randn([B, H, Mq, Mkv], device="cuda", dtype=dtype) * 3,
)
@staticmethod
def pad_bias(bias: torch.Tensor) -> torch.Tensor:
align_to = 16
if (bias.shape[-1] % align_to) == 0:
return bias
pad_count = align_to - (bias.shape[-1] % align_to)
return torch.nn.functional.pad(bias, [0, pad_count])[:, :, :, : bias.shape[-1]]
def test_f16_biasf32(self) -> None:
q, k, v, bias = self.create_tensors(torch.float16)
fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
bias = bias.to(torch.float32)
with pytest.raises((ValueError, RuntimeError)):
fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
@disable_on_rocm
def test_f32_biasf16(self) -> None:
q, k, v, bias = self.create_tensors(torch.float32)
fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
bias = bias.to(torch.float16)
with pytest.raises((ValueError, RuntimeError)):
fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_wrong_alignment(self, dtype) -> None:
op = fmha.cutlass.FwOp if torch.version.cuda else fmha.ck.FwOp
if dtype not in op.SUPPORTED_DTYPES:
pytest.skip(
f"{dtype=} is not supported by {op.__module__}.{op.__qualname__}"
)
q, k, v, bias = self.create_tensors(dtype, Mq=7, Mkv=5)
try:
fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=(op, None))
return
except (ValueError, RuntimeError):
pass
# This case is not supported, likely due to padding issues
# Let's make sure it works with padding
assert bias.ndim == 4, bias.shape
bias_padded = self.pad_bias(bias)
out = fmha.memory_efficient_attention(
q, k, v, attn_bias=bias_padded, op=(op, None)
).float()
ref_out = ref_attention_bmhk_for_test(q, k, v, bias)
assert_allclose(
out, ref_out, atol=op.ERROR_ATOL[dtype], rtol=op.ERROR_RTOL[dtype]
)
def test_permuted_attn_bias(self) -> None:
op = fmha.cutlass.FwOp
dtype = torch.float16
q, k, v, bias = self.create_tensors(dtype, Mq=7, Mkv=7)
bias = bias.transpose(-1, -2) # now `stride(-1) != 1`
# Either it works, or it raises an exception
# but we should never get a CUDA error
try:
out = fmha.memory_efficient_attention(
q, k, v, attn_bias=bias, op=(op, None)
).float()
ref_out = ref_attention_bmhk_for_test(q, k, v, bias)
assert_allclose(
out, ref_out, atol=op.ERROR_ATOL[dtype], rtol=op.ERROR_RTOL[dtype]
)
except (ValueError, RuntimeError):
pass
SM_AND_SHMEM_KBYTES = [
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
(50, 64),
(60, 64),
(70, 96),
(75, 64),
(80, 163),
(86, 99),
(89, 99),
# (90, 227),
]
@cuda_only
@disable_on_rocm
@skip_if_pt_cutlass
@pytest.mark.parametrize("dtype_str", ["f32", "f16", "bf16"])
@pytest.mark.parametrize(
"sm_shmem",
SM_AND_SHMEM_KBYTES,
ids=[f"cc{sm}_shmem{shmem}kb" for sm, shmem in SM_AND_SHMEM_KBYTES],
)
def test_has_kernel_for(sm_shmem: Tuple[int, int], dtype_str: str) -> None:
dtype = {"f32": torch.float, "f16": torch.half, "bf16": torch.bfloat16}[dtype_str]
sm, shmem_kbytes = sm_shmem
if sm < 80 and dtype_str == "bf16":
return
for k in [16, 32, 64, 128, 256]:
assert torch.ops.xformers._has_cutlassF_kernel_for(
dtype, sm, shmem_kbytes * 1024, k
), f"k={k}"
assert torch.ops.xformers._has_cutlassB_kernel_for(
dtype, sm, shmem_kbytes * 1024, k
), f"k={k}"
def test_window_size_materialize() -> None:
seqlens = [4, 6]
attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
q_seqlen=seqlens,
kv_seqlen=seqlens,
).make_local_attention(2)
mask = attn_bias.materialize(
(1, 1, sum(seqlens), sum(seqlens)),
device="cpu",
dtype=torch.float32,
)
true_mask = torch.log(
torch.Tensor(
[
[
[
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
]
]
]
)
)
assert torch.all(mask == true_mask)
@cuda_only
@pytest.mark.parametrize("Mq", [1, 512])
@pytest.mark.parametrize(
"opFW_biasT",
[
(op, biasT)
for op in ALL_FW_OPS_NO_UNPADDED_LSE
for biasT in op.SUPPORTED_ATTN_BIAS_TYPES
if op.SUPPORTS_BMGHK
],
ids=lambda o: f"{o[0].NAME}-{o[1].__name__}" if isinstance(o, tuple) else "",
)
def test_forward_gqa(opFW_biasT, Mq: int):
opFW, biasT = opFW_biasT
if Mq < 512 and (
issubclass(biasT, fmha.attn_bias.LowerTriangularMask)
or issubclass(biasT, fmha.attn_bias.BlockDiagonalCausalMask)
):
pytest.skip("undefined upper left")
B_Mq_Mkv_H_K_Kv = (3, Mq, 512, 16, 128, 128)
test_forward(
(
opFW,
"cuda",
torch.float16,
biasT,
*B_Mq_Mkv_H_K_Kv,
),
packed=False,
fmt="BMGHK",
g=2,
)
@cuda_only
@pytest.mark.parametrize(
"opBW",
[
fmha.flash.BwOp,
fmha.cutlass.BwOp,
],
)
def test_backward_gqa(opBW):
H = 8
B_Mq_Mkv_H_K_Kv = (3, 512, 512, H, 128, 128)
dtype = torch.float16
query, key, value, attn_bias = create_tensors(
*(opBW, "cuda", dtype, type(None), *B_Mq_Mkv_H_K_Kv),
attn_bias_requires_grad=False,
fmt="BMHK",
)
op = (fmha.cutlass.FwOp, opBW)
key = key[:, :, :1].expand(-1, -1, H, -1)
value = value[:, :, :1].expand(-1, -1, H, -1)
key.requires_grad_(True)
out = fmha.memory_efficient_attention(query, key, value, attn_bias=attn_bias)
out_ref = ref_attention_bmhk_for_test(query, key, value, attn_bias=attn_bias)
assert_allclose(
out.float(),
out_ref.float(),
atol=op[0].ERROR_ATOL[dtype],
rtol=op[0].ERROR_RTOL[dtype],
)
out.backward(query)
dk = key.grad
key.grad = None
out_ref.backward(query)
assert_allclose(
dk.float(),
key.grad.float(),
atol=op[1].ERROR_ATOL[dtype],
rtol=op[1].ERROR_RTOL[dtype],
)
@cuda_only
@pytest.mark.parametrize(
"opFW", [op for op in ALL_FW_OPS_NO_UNPADDED_LSE if op.SUPPORTS_BMGHK]
)
def test_forward_gqa_one_group(opFW):
dtype = torch.float16
B, Mq, Mkv, H, K = 3, 13, 16, 5, 128
q = torch.randn([B, Mq, 1, H, K], dtype=dtype, device="cuda") * 3
k = torch.randn([B, Mkv, 1, H, K], dtype=dtype, device="cuda") * 3
v = torch.randn([B, Mkv, 1, H, K], dtype=dtype, device="cuda") * 3
supported = opFW.supports(fmha.Inputs(q, k, v))
if not supported:
supported_bmhk = opFW.supports(fmha.Inputs(q[:, :, 0], k[:, :, 0], v[:, :, 0]))
assert supported == supported_bmhk
pytest.skip("not supported")
out = fmha.memory_efficient_attention_forward(q, k, v, op=opFW)
ref = ref_attention_for_test(q, k, v)
assert_allclose(
out.float(),
ref,
atol=opFW.ERROR_ATOL[dtype],
rtol=opFW.ERROR_RTOL.get(dtype, 1e-5),
)
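# Illustrative sketch (not part of the test suite): the BMGHK layout used for
# grouped-query attention in the GQA tests above. Each of the G groups stores a
# single K/V head that is broadcast to the group's H query heads via `expand`,
# so no extra memory is allocated (the expanded head dimension has stride 0).
# Sizes are arbitrary and tensors stay on CPU since no kernel is invoked here.
def _sketch_bmghk_gqa_layout(B=2, M=4, G=3, H=5, K=8) -> fmha.Inputs:
    q = torch.randn(B, M, G, H, K)
    k = torch.randn(B, M, G, 1, K).expand(B, M, G, H, K)
    v = torch.randn(B, M, G, 1, K).expand(B, M, G, H, K)
    assert k.stride(3) == 0 and v.stride(3) == 0  # broadcast, not copied
    return fmha.Inputs(q, k, v)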
@sm80_or_better_only
@disable_on_rocm
def test_flash_gqa_wrong_strides() -> None:
op = (fmha.flash.FwOp, None)
device = "cuda"
B, Mq, Mkv, G, H, K = 3, 1, 512, 2, 8, 128
q = torch.empty((B, Mq, G, H, K), dtype=torch.float16, device=device)
kv = torch.empty((B, Mkv, G, H, K), dtype=torch.float16, device=device)
fmha.memory_efficient_attention(q, kv, kv, op=op)
kv = torch.empty((B, Mkv, H, G, K), dtype=torch.float16, device=device).permute(
0, 1, 3, 2, 4
)
with pytest.raises(ValueError):
fmha.memory_efficient_attention(q, kv, kv, op=op)
kv = torch.empty((B, Mkv, G, 1, K), dtype=torch.float16, device=device)
with pytest.raises(ValueError):
fmha.memory_efficient_attention(q, kv, kv, op=op)
kv = kv.expand(-1, -1, -1, H, K)
fmha.memory_efficient_attention(q, kv, kv, op=op)
kv = torch.empty((B, Mkv, G, H, 2 * K), dtype=torch.float16, device=device)[
:, :, :, :, :K
]
fmha.memory_efficient_attention(q, kv, kv, op=op)
def _dispatches_to_splitK(q, kv):
return (
_dispatch_fw_priority_list(fmha.Inputs(q, kv, kv), False)[0]
is fmha.triton_splitk.FwOp
)
def _dispatches_to_flash_decoding(q, kv):
return (
_dispatch_fw_priority_list(fmha.Inputs(q, kv, kv), False)[0] is fmha.flash.FwOp
)
@disable_on_rocm
def test_dispatch_decoding_bmhk() -> None:
assert not _dispatches_to_splitK(
torch.empty([1, 8, 1, 128]), torch.empty([1, 2048, 1, 128])
), "Should not use SplitK with 1 head (no tensorcores)"
assert _dispatches_to_flash_decoding(
torch.empty([1, 8, 32, 128]),
torch.empty([1, 2048, 1, 128]).expand(-1, -1, 32, -1),
), "Should use Flash-Decoding with BMHK MQA"
assert not _dispatches_to_splitK(
torch.empty([1, 8, 32, 128]),
torch.empty([1, 2048, 32, 128]),
), "Should not use SplitK when no TensorCores"
assert not _dispatches_to_splitK(
torch.empty([1, 128, 32, 128]),
torch.empty([1, 2048, 1, 128]).expand(-1, -1, 32, -1),
), "Should not use SplitK if q seqlen is long"
assert not _dispatches_to_splitK(
torch.empty([128, 8, 32, 128]),
torch.empty([128, 2048, 1, 128]).expand(-1, -1, 32, -1),
), "Should not use SplitK if B is big"
@disable_on_rocm
def test_dispatch_decoding_bmghk() -> None:
assert not _dispatches_to_splitK(
torch.empty([1, 8, 1, 1, 128]), torch.empty([1, 2048, 1, 1, 128])
), "Should not use SplitK with 1 head (no tensorcores)"
assert _dispatches_to_flash_decoding(
torch.empty([1, 8, 1, 32, 128]),
torch.empty([1, 2048, 1, 1, 128]).expand(-1, -1, -1, 32, -1),
), "Should use Flash-Decoding with MQA"
assert _dispatches_to_flash_decoding(
torch.empty([1, 8, 4, 32, 128]),
torch.empty([1, 2048, 4, 1, 128]).expand(-1, -1, -1, 32, -1),
), "Should use Flash-Decoding with GQA"
assert not _dispatches_to_splitK(
torch.empty([1, 8, 1, 32, 128]),
torch.empty([1, 2048, 1, 32, 128]),
), "Should not use SplitK when no TensorCores"
assert not _dispatches_to_splitK(
torch.empty([1, 128, 1, 32, 128]),
torch.empty([1, 2048, 1, 1, 128]).expand(-1, -1, -1, 32, -1),
), "Should not use SplitK if q seqlen is long"
assert not _dispatches_to_splitK(
torch.empty([128, 8, 1, 32, 128]),
torch.empty([128, 2048, 1, 1, 128]).expand(-1, -1, -1, 32, -1),
), "Should not use SplitK if B is big"
shapes_triton_splitk = [
(1, 8, 2**16, 1, 128, 128),
(1, 4, 2**16, 1, 128, 128),
(1, 16, 2**16, 1, 128, 128),
(1, 16, 2**16, 1, 32, 32),
(1, 8, 1025, 1, 128, 128),
(2, 8, 4096, 1, 128, 128),
(10, 8, 2**16, 1, 128, 128),
(10, 15, 2**16, 1, 128, 128),
(1, 3, 2**16, 1, 128, 128),
(1, 3, 2**16 - 10, 1, 128, 128),
(2, 3, 73, 1, 128, 128),
(2, 7, 7328, 1, 128, 128),
(2, 7, 7328, 1, 120, 120),
(2, 7, 63, 1, 120, 120),
]
op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_splitk = [
(fmha.triton_splitk.FwOp, "cuda", torch.float16, type(None), *s)
for s in shapes_triton_splitk
] + [
(fmha.triton_splitk.FwOp, "cuda", torch.bfloat16, type(None), *s)
for s in shapes_triton_splitk
]
@pytest.mark.parametrize(
"opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv",
op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_splitk,
ids=[make_id(*c) for c in op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_splitk],
)
@cuda_only
def test_forward_splitk(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
packed=False,
fmt="BMHK",
):
test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed=packed, fmt=fmt)
@cuda_only
@pytest.mark.parametrize(
"op",
[fmha.triton_splitk.FwOp, fmha.flash.FwOp, fmha.ck.FwOp],
ids=lambda op: op.NAME,
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=str)
@pytest.mark.parametrize(
"B_Mkv_H_K",
[
(1, 2**16, 3, 128),
(5, 53, 4, 64),
(7, 51, 4, 256),
(3, 51, 2, 512),
],
)
def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K):
B, Mkv, H, K = B_Mkv_H_K
q = torch.randn([B, 1, H, K], dtype=dtype, device="cuda") * 3
k = torch.randn([B, Mkv, 1, K], dtype=dtype, device="cuda") * 3
v = torch.randn([B, Mkv, 1, K], dtype=dtype, device="cuda") * 3
k = k.expand(-1, -1, H, -1)
v = v.expand(-1, -1, H, -1)
if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, k, v)):
pytest.skip("; ".join(skip_reasons))
out = fmha.memory_efficient_attention_forward(q, k, v, op=op)
ref = ref_attention_for_test(q, k, v)
assert_allclose(
out.float(),
ref,
atol=op.ERROR_ATOL[dtype],
rtol=op.ERROR_RTOL.get(dtype, 1e-5),
)
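# Illustrative sketch (not part of the test suite): multi-query attention in
# test_mqa_decoding is expressed by storing a single K/V head and expanding it
# across all H query heads. `expand` gives the head dimension stride 0, so the
# H heads alias the same storage instead of being materialized H times.
# Sizes are arbitrary; plain CPU tensors are enough to show the layout.
def _sketch_mqa_expand(B=2, Mkv=16, H=4, K=8) -> torch.Tensor:
    k_single = torch.randn(B, Mkv, 1, K)
    k_mqa = k_single.expand(B, Mkv, H, K)
    assert k_mqa.stride(2) == 0  # all H heads share one copy of the data
    assert k_mqa.data_ptr() == k_single.data_ptr()
    return k_mqa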
@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_empty_tensors_empty_query(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
)
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
if torch.version.hip:
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")
query = query[:, :0]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
assert out.shape[1] == 0
out.backward(out)
# dK/dV should be all zeros
assert_allclose(key.grad, torch.zeros_like(key.grad), "key.grad")
assert_allclose(value.grad, torch.zeros_like(value.grad), "value.grad")
@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_empty_tensors_empty_kv(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
)
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
if opFW == fmha.triton_splitk.FwOp:
pytest.skip("triton_splitk doesn't support empty kv")
if torch.version.hip:
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")
key = key[:, :0]
value = value[:, :0]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
assert_allclose(out, torch.zeros_like(out), "out")
out.backward(out)
# dQ should be all zeros
assert_allclose(query.grad, torch.zeros_like(query.grad), "query.grad")
@parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs
def test_empty_tensors_empty_b(
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
):
query, key, value, attn_bias = create_tensors(
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv,
fmt="BMHK",
)
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0]
if torch.version.hip:
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!")
query, key, value = query[:0], key[:0], value[:0]
query.requires_grad_(True)
key.requires_grad_(True)
value.requires_grad_(True)
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None))
out.backward(out)
def test_local_attn_bias() -> None:
mask = (
fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
.materialize(shape=(4, 4))
.exp()
)
expected = torch.tensor(
[[1, 1, 1, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.float32
)
assert (mask == expected).all().item()
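# Illustrative sketch (not part of the test suite): with equal query and key
# lengths, LocalAttentionFromBottomRightMask lets query i attend to keys j with
# i - window_left <= j <= i + window_right, which reproduces the `expected`
# matrix above. This hand-rolled reference only covers the square case; the
# library bias also handles Mq != Mkv (aligned to the bottom-right corner).
def _sketch_local_from_bottom_right(n=4, window_left=1, window_right=2) -> torch.Tensor:
    rows = torch.arange(n)[:, None]
    cols = torch.arange(n)[None, :]
    return ((cols >= rows - window_left) & (cols <= rows + window_right)).float()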
@cuda_only
@disable_on_rocm
@skip_if_pt_cutlass
@pytest.mark.parametrize("cc", [60, 70, 80])
@pytest.mark.parametrize("maxK", [32, 64, 128, 256])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize(
"custom_mask_type",
[
fmha.cutlass._CustomMaskType.NoCustomMask,
fmha.cutlass._CustomMaskType.CausalFromTopLeft,
fmha.cutlass._CustomMaskType.CausalFromBottomRight,
],
)
@pytest.mark.parametrize("window_size", [0, 3, 300])
@pytest.mark.parametrize(
"num_queries,num_keys",
[
(30, 66),
(256, 256),
# Edge cases
(314, 320),
(32, 256),
(224, 226),
(5, 531),
(320, 332), # for win_size=300
# Others
(256, 62),
(256, 63),
(256, 64),
(256, 65),
(256, 66),
],
)
def test_cutlassB_iter_order(
dtype,
cc: int,
maxK: int,
num_queries: int,
num_keys: int,
custom_mask_type,
window_size,
) -> None:
"""
This tests some internals of the cutlassB kernel
We test the iteration across blocks of [queries, keys] to ensure
that we correctly:
* Iterate over all the blocks that should be iterated
* Do *not* iterate over blocks that are completely masked out
* Correctly compute the number of parallel blocks that will compute
the same block of dQ
.. and we test this across variable causal masks+local attention combinations
"""
if (
window_size > 0
and custom_mask_type == fmha.cutlass._CustomMaskType.NoCustomMask
):
pytest.skip("LocalAttention is only supported for causal")
get_iteration_data = partial(
torch.ops.xformers._cutlassB_iteration_data,
dtype=dtype,
cc=cc,
maxK=maxK,
num_queries=num_queries,
num_keys=num_keys,
custom_mask_type=custom_mask_type,
window_size=window_size,
)
bias = torch.zeros([num_queries, num_keys], dtype=torch.float32)
if custom_mask_type != fmha.cutlass._CustomMaskType.NoCustomMask:
bias = fmha.attn_bias._materialize_causal_mask(
(num_queries, num_keys),
dtype=torch.float32,
device="cpu",
window_size=None if window_size == 0 else window_size,
from_bottomright=(
custom_mask_type == fmha.cutlass._CustomMaskType.CausalFromBottomRight
),
)
block_queries, block_keys = get_iteration_data()[:2]
mask_pooled = (
F.max_pool2d(bias.unsqueeze(0), (block_queries, block_keys), ceil_mode=True)
== 0
).int()[0]
attn_computed = torch.zeros_like(mask_pooled)
for key_start in range(0, num_keys, block_keys):
it = 0
new_key_start = key_start
new_query_start = get_iteration_data(key_start=key_start)[2]
try:
expected_first_query = (
mask_pooled[:, key_start // block_keys].tolist().index(1)
* block_queries
)
assert (
new_query_start == expected_first_query
), f"Wrong first query for K={key_start}: {new_query_start} (expected {expected_first_query})"
except ValueError: # Nothing to compute in this column
pass
while new_key_start == key_start and new_query_start < num_queries:
query_start = new_query_start
attn_computed[query_start // block_queries, key_start // block_keys] += 1
# print(f"Compute [{query_start}, {key_start}]")
# Is there something to compute here?
assert mask_pooled[
query_start // block_queries, key_start // block_keys
].item(), "Computing a block that is not needed!"
new_query_start, new_key_start = get_iteration_data(
key_start=key_start, query_start=query_start
)[3:5]
it += 1
            assert it < num_queries, "iteration over query blocks did not terminate"
assert (attn_computed == mask_pooled)[
:, key_start // block_keys
].all(), "some blocks were not computed!"
# Now check that the number returned by `getNumParallelBlocksForQuery` is correct
for query_start in range(0, num_queries, block_queries):
num_parallel_blocks = get_iteration_data(
query_start=query_start, num_splits_key=num_keys
)[5]
num_actual = mask_pooled[query_start // block_queries].sum().item()
assert num_parallel_blocks == num_actual
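# Illustrative sketch (not part of the test suite): the core trick used in
# test_cutlassB_iter_order to decide which [query, key] tiles the backward
# kernel must visit. A dense bias (0 for visible, -inf for masked) is
# max-pooled down to tile granularity; a tile needs work iff its pooled value
# is 0, i.e. at least one element inside it is unmasked. Tile sizes are
# arbitrary here and would normally come from the kernel configuration.
def _sketch_nonempty_tiles(bias: torch.Tensor, block_q: int, block_k: int) -> torch.Tensor:
    pooled = F.max_pool2d(bias.unsqueeze(0), (block_q, block_k), ceil_mode=True)
    return (pooled == 0).int()[0]  # 1 where the tile has at least one visible entry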
@sm80_or_better_only
@pytest.mark.parametrize("B", [1, 5, 128])
@pytest.mark.parametrize("MAX_T", [64, 128, 2048, 4096, 8192])
@pytest.mark.parametrize(
"op",
[
fmha.triton_splitk.FwOp,
fmha.triton_splitk.FwOp_S8,
fmha.triton_splitk.FwOp_Map[48],
],
ids=lambda op: op.NAME,
)
@pytest.mark.parametrize("num_quant_groups", [0, 1, 8])
@pytest.mark.parametrize("page_size", [64, 128, 256])
@pytest.mark.parametrize("gappy", [False, True], ids=lambda x: "gappy" if x else "")
def test_paged_attention(
B,
MAX_T: int,
num_quant_groups: int,
page_size: int,
op: Type[AttentionFwOpBase],
gappy: bool,
):
paged_attention_run_inner(
B, MAX_T, num_quant_groups, page_size, op, bench=False, gappy=gappy
)
@sm80_or_better_only
@disable_on_rocm
@pytest.mark.parametrize("B", [1, 5, 128])
@pytest.mark.parametrize("MAX_T", [64, 128, 2048, 4096, 8192])
@pytest.mark.parametrize("page_size", [256])
def test_paged_attention_flash(B, MAX_T: int, page_size: int):
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
op = fmha.flash.FwOp
if (
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask
not in op.SUPPORTED_ATTN_BIAS_TYPES
):
pytest.skip("Not supported bias")
num_quant_groups = 0
paged_attention_run_inner(B, MAX_T, num_quant_groups, page_size, op, bench=False)
def paged_attention_run_inner(
B: int,
MAX_T: int,
num_quant_groups: int,
page_size: int,
op: Type[AttentionFwOpBase],
bench: bool,
gappy: bool = False,
) -> None:
import triton
torch.manual_seed(10)
TEST_WARMUP_MS = 500
TEST_RUN_MS = 5000
N_H_L = 8
N_KVH_L = 1
D_H = 128
D_H_KV = D_H // 8 + num_quant_groups if num_quant_groups else D_H
kv_seqlens = torch.randint(low=1, high=MAX_T + 1, size=(B,)).tolist()
    # Paged attention requires k.shape[1] and v.shape[1] to be divisible by page_size,
    # so pad up to the next multiple (see the geometry sketch after this function)
padded_per_row_len = ((MAX_T + page_size - 1) // page_size) * page_size
if gappy:
make_paged_kwargs = {
"paged_type": fmha.attn_bias.PagedBlockDiagonalGappyKeysMask,
"notional_padding": MAX_T,
}
attn_bias = fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens(
q_seqlen=[1] * B,
kv_seqstarts=list(range(0, MAX_T * (B + 1), MAX_T)),
kv_seqlen=kv_seqlens,
)
else:
make_paged_kwargs = {
"paged_type": fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
}
block_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask
attn_bias = block_type.from_seqlens( # type: ignore
q_seqlen=[1] * B,
kv_padding=MAX_T,
kv_seqlen=kv_seqlens,
)
q = torch.randn((B, 1, N_H_L, D_H), dtype=torch.bfloat16, device="cuda")
if num_quant_groups:
if triton.__version__[:4] < "3.0.":
            pytest.skip("dequant needs triton updates")
        # Using high=64 below, because with 256 both paged and non-paged paths
        # will produce NaNs - probably some quantization coefficients are NaNs
        # after the bitwise cast.
cache_k = torch.randint(
0, 64, (B, MAX_T, N_KVH_L, D_H_KV * 4), dtype=torch.uint8, device="cuda"
)
cache_k = cache_k.view(dtype=torch.int32)
cache_v = torch.randint(
0, 64, (B, MAX_T, N_KVH_L, D_H_KV * 4), dtype=torch.uint8, device="cuda"
)
cache_v = cache_v.view(dtype=torch.int32)
op = type(
f"{op.__name__}_{num_quant_groups}",
(op,),
{"NUM_GROUPS": num_quant_groups},
)
else:
cache_k = torch.randn(
(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda"
)
cache_v = torch.randn_like(cache_k)
axq = q.view(1, B * 1, N_H_L, D_H)
axk = cache_k.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand(
1, B * MAX_T, N_H_L, D_H_KV
)
axv = cache_v.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand(
1, B * MAX_T, N_H_L, D_H_KV
)
k_cache_size_usual = axk.numel()
# First, create "wasteful" K/V cache, where every block in logical cache
# has a physical representation, even if there's nothing stored there
block_tables = torch.arange(
B * padded_per_row_len // page_size, device="cuda", dtype=torch.int32
).reshape(B, -1)
shape_padded = (B, padded_per_row_len, N_KVH_L, D_H_KV)
axk_padded = torch.empty(shape_padded, device=axk.device, dtype=axk.dtype)
axv_padded = torch.empty(shape_padded, device=axv.device, dtype=axv.dtype)
axk_padded[:, :MAX_T] = axk.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :]
axv_padded[:, :MAX_T] = axv.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :]
axk_padded = axk_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV)
axv_padded = axv_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV)
axk_padded = axk_padded.expand(-1, -1, N_H_L, -1)
axv_padded = axv_padded.expand(-1, -1, N_H_L, -1)
attn_bias_paged = attn_bias.make_paged(
block_tables=block_tables, page_size=page_size, **make_paged_kwargs # type: ignore
)
y_usual = fmha.memory_efficient_attention_forward(
axq,
axk,
axv,
attn_bias,
op=op,
)
if bench:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_usual = fmha.memory_efficient_attention_forward(
axq,
axk,
axv,
attn_bias,
op=op,
)
t_ms = triton.testing.do_bench(
lambda g=g: g.replay(),
warmup=TEST_WARMUP_MS,
rep=TEST_RUN_MS,
)
logger.info(f"Non-paged attention took {t_ms * 1e3:.2f}us")
y_wasteful = fmha.memory_efficient_attention_forward(
axq,
axk_padded,
axv_padded,
attn_bias_paged,
op=op,
)
if bench:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_wasteful = fmha.memory_efficient_attention_forward(
axq,
axk_padded,
axv_padded,
attn_bias_paged,
op=op,
)
t_ms = triton.testing.do_bench(
lambda g=g: g.replay(),
warmup=TEST_WARMUP_MS,
rep=TEST_RUN_MS,
)
logger.info(f"Paged attention with wasteful K/V-cache took {t_ms * 1e3:.2f}us")
torch.testing.assert_close(
y_wasteful,
y_usual,
atol=1.0e-2,
rtol=1.0e-2,
)
    # Now let's create a "packed" K/V cache, where only meaningful logical blocks are mapped to physical blocks
(block_tables, packed_cache_k, packed_cache_v) = pack_kv_cache(
cache_k,
cache_v,
kv_seqlens,
page_size,
)
attn_bias_paged = attn_bias.make_paged(
block_tables=block_tables, page_size=page_size, **make_paged_kwargs # type: ignore
)
axk = packed_cache_k.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV)
axv = packed_cache_v.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV)
k_cache_size_packed = axk.numel()
y_packed = fmha.memory_efficient_attention_forward(
axq,
axk,
axv,
attn_bias_paged,
op=op,
)
logger.info(
f"KV-cache size reduced by {(100 * (1 - k_cache_size_packed/k_cache_size_usual)):.2f}%"
)
torch.testing.assert_close(y_wasteful, y_packed)
# Let's swap two blocks, and adjust two corresponding entries in the block table. The result shouldn't change
i, j = 0, axk.shape[1] // page_size - 1
axk = axk[:, :, :1, :]
axv = axv[:, :, :1, :]
vals_i = axk[:, i * page_size : (i + 1) * page_size, :, :].clone()
vals_j = axk[:, j * page_size : (j + 1) * page_size, :, :].clone()
axk[:, i * page_size : (i + 1) * page_size, :, :] = vals_j
axk[:, j * page_size : (j + 1) * page_size, :, :] = vals_i
vals_i = axv[:, i * page_size : (i + 1) * page_size, :, :].clone()
vals_j = axv[:, j * page_size : (j + 1) * page_size, :, :].clone()
axv[:, i * page_size : (i + 1) * page_size, :, :] = vals_j
axv[:, j * page_size : (j + 1) * page_size, :, :] = vals_i
axk = axk.expand(-1, -1, N_H_L, -1)
axv = axv.expand(-1, -1, N_H_L, -1)
where_i = block_tables == i
where_j = block_tables == j
block_tables.masked_fill_(where_i, j)
block_tables.masked_fill_(where_j, i)
y_swapped = fmha.memory_efficient_attention_forward(
axq,
axk,
axv,
attn_bias_paged,
op=op,
)
if bench:
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_swapped = fmha.memory_efficient_attention_forward(
axq,
axk,
axv,
attn_bias_paged,
op=op,
)
t_ms = triton.testing.do_bench(
lambda g=g: g.replay(),
warmup=TEST_WARMUP_MS,
rep=TEST_RUN_MS,
)
logger.info(f"Paged attention with packed K/V-cache took {t_ms * 1e3:.2f}us")
torch.testing.assert_close(y_swapped, y_packed)
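# Illustrative sketch (not part of the test suite): the paged K/V geometry set
# up inside paged_attention_run_inner. Each sequence row is padded up to a
# multiple of `page_size`, and a block table maps every logical page of every
# sequence to a physical page; the "wasteful" layout above is just the identity
# mapping. All sizes below are made up and no attention kernel is called.
def _sketch_paged_geometry(B=3, max_t=1000, page_size=256):
    padded_len = ((max_t + page_size - 1) // page_size) * page_size  # ceil to a page multiple
    pages_per_seq = padded_len // page_size
    block_tables = torch.arange(B * pages_per_seq, dtype=torch.int32).reshape(
        B, pages_per_seq
    )
    return padded_len, block_tables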
@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
"op",
[
fmha.triton_splitk.FwOp,
fmha.flash.FwOp,
None,
],
ids=lambda op: "None" if op is None else op.NAME,
)
@pytest.mark.parametrize("G,H", [(1, 11), (7, 1), (1, 1), (7, 11), (None, 11)])
@pytest.mark.parametrize(
"write_lse", (False, True), ids=lambda x: "write_lse" if x else ""
)
@pytest.mark.parametrize(
"stack_inputs", (False, True), ids=lambda x: "stack_inputs" if x else ""
)
def test_merge_attentions_nobias(
write_lse: bool,
stack_inputs: bool,
op: Type[AttentionFwOpBase],
G: Optional[int],
H: int,
):
"""
Merging the same attention twice shouldn't change anything.
    This also tests the shape of the lse output of each permitted op
    (see the sketch after this test for the underlying identity).
"""
B, M, Mq, K = 13, 5, 3, 128
if op is None or torch.bfloat16 in op.SUPPORTED_DTYPES:
dtype = torch.bfloat16
else:
dtype = next(iter(op.SUPPORTED_DTYPES))
if G is None:
q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda")
k = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
v = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K)
else:
q = 3 * torch.rand(B, Mq, G, H, K, dtype=dtype, device="cuda")
k = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
B, M, G, H, K
)
v = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand(
B, M, G, H, K
)
out1, lse1 = fmha.memory_efficient_attention_partial(q, k, v, op=op)
assert out1.shape == q.shape
M_ceil = lse1.shape[-1]
assert M_ceil >= Mq
    assert lse1.shape == ((B, H, M_ceil) if G is None else (B, G, H, M_ceil))
lse1 = lse1[..., :Mq]
attn_chunks = [out1, out1]
lse_chunks = [lse1, lse1]
attn_chunks_ = torch.stack(attn_chunks) if stack_inputs else attn_chunks
lse_chunks_ = torch.stack(lse_chunks) if stack_inputs else lse_chunks
out, lse = fmha.merge_attentions(attn_chunks_, lse_chunks_, write_lse=write_lse) # type: ignore
assert out.shape == out1.shape
assert_allclose(out1, out, rtol=1e-3, atol=1e-3, msg="out")
if write_lse:
assert lse is not None
assert lse.shape[:-1] == lse1.shape[:-1]
assert_allclose(
lse1[..., :Mq] + math.log(2), lse[..., :Mq], rtol=1e-3, atol=1e-3, msg="lse"
)
else:
assert lse is None
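# Illustrative sketch (not part of the test suite): why merging the same
# partial result N times leaves the attention output unchanged but shifts the
# LSE by log(N), which is what the test above checks for N = 2. With identical
# (attn, lse) for every chunk, the merge weights exp(lse - lse_max) are all 1:
#   out = N * 1 * attn / (N * 1) = attn
#   lse_merged = lse_max + log(N * 1) = lse + log(N)
def _sketch_merge_identical_chunks(n: int = 2):
    attn = torch.randn(5)  # stand-in for one partial attention output
    lse = torch.randn(5)   # stand-in for its log-sum-exp
    w = torch.exp(lse - lse)              # weights are exactly 1 for every chunk
    out = (n * w * attn) / (n * w)        # == attn
    lse_merged = lse + torch.log(n * w)   # == lse + log(n)
    return out, lse_merged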
@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
"op",
[
pytest.param(fmha.flash.FwOp, id="flashfwd"),
pytest.param((fmha.flash.FwOp, fmha.cutlass.BwOp), id="flashcutlass"),
# pytest.param((fmha.triton_splitk.FwOp, fmha.cutlass.BwOp), id="splitk"), # XXX
pytest.param(fmha.MemoryEfficientAttentionFlashAttentionOp, id="flash"),
None,
],
)
def test_merge_attentions_nobias_bwd(
op: Union[Type[AttentionFwOpBase], fmha.AttentionOp]
):
B, M, Mq, H, K = 13, 5, 5, 4, 128
dtype = torch.bfloat16
nparts = 3
torch.manual_seed(1)
q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda")
kv = [
[3 * (torch.rand(B, M, H, K, dtype=dtype, device="cuda")) for _ in range(2)]
for _ in range(nparts)
]
q = q.requires_grad_(True)
kv = [[j.requires_grad_(True) for j in i] for i in kv]
out_parts = [fmha.memory_efficient_attention_partial(q, k, v, op=op) for k, v in kv]
attn_split, lse_split = [list(x) for x in zip(*out_parts)]
out_merged = fmha.merge_attentions(attn_split, lse_split, write_lse=True)[0]
grad_out = torch.rand_like(q)
out_merged.backward(grad_out)
grad_q_out = q.grad
assert q.grad is not None
grad_kv_out = [[j.grad for j in i] for i in kv]
q = q.detach().requires_grad_(True)
kv = [[j.detach().requires_grad_(True) for j in i] for i in kv]
k2, v2 = [torch.cat([i[j] for i in kv], dim=1) for j in range(2)]
if op is None or isinstance(op, tuple):
full_op = op
else:
full_op = (op, None)
out_full = fmha.memory_efficient_attention(q, k2, v2, op=full_op) # type: ignore
out_full.backward(grad_out)
assert_allclose(
out_merged, out_full.to(out_merged.dtype), rtol=1e-2, atol=2e-2, msg="out"
)
atol = fmha.AttentionBwOpBase.ERROR_ATOL[dtype] * 1.5
rtol = fmha.AttentionBwOpBase.ERROR_RTOL[dtype]
assert_allclose(grad_q_out, q.grad, rtol=rtol, atol=atol, msg="qgrad")
for i in range(nparts):
for j in range(2):
assert_allclose(
grad_kv_out[i][j],
kv[i][j].grad,
rtol=rtol,
atol=atol,
msg=f"kvgrad {i} {j}",
)
@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
"dtype,op",
[
(torch.bfloat16, fmha.triton_splitk.FwOp_S1),
# Cutlass's LSE is not consistent
# (torch.float32, fmha.cutlass.FwOp),
(torch.bfloat16, fmha.flash.FwOp),
],
ids=lambda o: f"{o.NAME}" if hasattr(o, "NAME") else str(o),
)
@pytest.mark.parametrize("num_queries", [1])
@pytest.mark.parametrize("bmghk", [True, False], ids=lambda x: "bmghk" if x else "")
def test_partial_paged(
dtype: torch.dtype, op: Type[AttentionFwOpBase], num_queries: int, bmghk: bool
):
B = 128
N_H_L = 8
D_H = 128
page_size = 256
G = 2 if bmghk else 1
block_tables = torch.zeros((B, 1), dtype=torch.int32, device="cuda")
torch.manual_seed(1)
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None
B_T = num_queries * B
q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((1, page_size, G, 1, D_H), dtype=dtype, device="cuda")
v = torch.randn_like(k)
k = k.expand(1, page_size, G, N_H_L, D_H)
v = v.expand(1, page_size, G, N_H_L, D_H)
if not bmghk:
q = q[:, :, 0]
k = k[:, :, 0]
v = v[:, :, 0]
attn_bias = (
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[num_queries] * B,
kv_seqlen=[1] + ([100] * (B - 1)),
page_size=page_size,
block_tables=block_tables,
)
)
if attn_bias not in op.SUPPORTED_ATTN_BIAS_TYPES:
pytest.skip("Not supported bias")
attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
q,
k,
v,
attn_bias,
op=op,
output_dtype=output_dtype,
)
if bmghk:
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
assert lse_chunk.shape == (
1,
G,
N_H_L,
B_T,
), f"{lse_chunk.shape=}, {(1, G, N_H_L, B_T)=}"
else:
assert attn_chunk.shape == (1, B_T, N_H_L, D_H)
assert lse_chunk.shape == (
1,
N_H_L,
B_T,
), f"{lse_chunk.shape=}, {(1, N_H_L, B_T)=}"
@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
"dtype,op",
[
(torch.bfloat16, fmha.triton_splitk.FwOp_S1),
(torch.bfloat16, fmha.triton_splitk.FwOp_S32),
# Cutlass's LSE is not consistent
# (torch.float32, fmha.cutlass.FwOp),
(torch.bfloat16, fmha.flash.FwOp),
],
ids=lambda o: f"{o.NAME}" if hasattr(o, "NAME") else str(o),
)
@pytest.mark.parametrize("num_queries", [1, 2])
@pytest.mark.parametrize("bmghk", [True, False], ids=lambda x: "bmghk" if x else "")
@pytest.mark.parametrize(
"stack_inputs", (False, True), ids=lambda x: "stack_inputs" if x else ""
)
def test_merge_attentions_decoding(
dtype: torch.dtype,
op: Type[AttentionFwOpBase],
num_queries: int,
bmghk: bool,
stack_inputs: bool,
):
"""
Compute decoding attention on chunks of K/V and merge them together.
Compare with computing attention on the whole K/V.
"""
MAX_T = 8192
B = 128
N_H_L = 8
D_H = 128
G = 2 if bmghk else 1
torch.manual_seed(1)
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None
num_chunks = 10
chunk_starts = sorted(
torch.randint(low=1, high=MAX_T // 2, size=(num_chunks,)).tolist()
)
chunk_starts[0] = 0
chunk_starts.append(MAX_T)
    # We construct sequences so that even the last chunk contains a non-empty part of every
    # sequence, at least as long as the number of queries.
    # Otherwise the corresponding LSE would be -inf and would propagate to the whole sum
    # (see the sketch after this test).
    # It is possible to teach the kernel to ignore infinite LSEs, but in practical use cases
    # of merging attention, e.g. a batch of sequences with a common prefix, this condition
    # is satisfied anyway.
k_lens = torch.randint(
low=chunk_starts[-2] + num_queries, high=MAX_T, size=(B,)
).tolist()
q_lens = [num_queries] * B
B_T = num_queries * B
q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((B, MAX_T, G, 1, D_H), dtype=dtype, device="cuda")
v = torch.randn_like(k)
if not bmghk:
q = q[:, :, 0]
# Compute per-chunk attention
chunks_output = []
for i in range(num_chunks):
chunk_start, chunk_end = chunk_starts[i], chunk_starts[i + 1]
k_chunk = k[:, chunk_start:chunk_end, ...]
v_chunk = v[:, chunk_start:chunk_end, ...]
axk = k_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
axv = v_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
if not bmghk:
axk = axk[:, :, 0]
axv = axv[:, :, 0]
bias_type = fmha.attn_bias.BlockDiagonalPaddedKeysMask
if i + 1 == num_chunks:
bias_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask
attn_bias = bias_type.from_seqlens(
q_seqlen=q_lens,
kv_padding=chunk_end - chunk_start,
kv_seqlen=[max(min(x, chunk_end) - chunk_start, 0) for x in k_lens],
)
attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial(
q,
axk,
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)
if bmghk:
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H)
assert lse_chunk.shape == (1, G, N_H_L, B_T)
else:
assert attn_chunk.shape == (1, B_T, N_H_L, D_H)
assert lse_chunk.shape == (1, N_H_L, B_T)
chunks_output.append((attn_chunk, lse_chunk))
# Merge attention from all chunks
attn_split = [attn_chunk for attn_chunk, _ in chunks_output]
lse_split = [lse_chunk for _, lse_chunk in chunks_output]
attn_split_ = torch.stack(attn_split) if stack_inputs else attn_split
lse_split_ = torch.stack(lse_split) if stack_inputs else lse_split
attn_out, lse_out = fmha.merge_attentions(
attn_split_, lse_split_, output_dtype=dtype # type: ignore
)
assert lse_out is not None
# Compute attention on the full K/V
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=q_lens,
kv_padding=MAX_T,
kv_seqlen=k_lens,
)
axk = k.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
axv = v.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H)
if not bmghk:
axk = axk[:, :, 0]
axv = axv[:, :, 0]
attn_full, lse_full = fmha.memory_efficient_attention_partial(
q,
axk,
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)
assert_allclose(
lse_out.to(lse_full.dtype), lse_full, rtol=1e-3, atol=1e-3, msg="lse"
)
assert_allclose(
attn_out.to(attn_full.dtype), attn_full, rtol=1e-3, atol=1e-3, msg="out"
)
attn_full2 = fmha.memory_efficient_attention_forward(
q,
axk,
axv,
attn_bias,
op=op,
output_dtype=output_dtype,
)
assert_allclose(attn_full2, attn_full, rtol=1e-3, atol=1e-3, msg="out2")
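# Illustrative sketch (not part of the test suite), referenced by the comment
# about chunk lengths in test_merge_attentions_decoding above: if one chunk
# sees no keys for some query, its LSE is -inf and its partial output for that
# query is undefined (often NaN). In the merge, the zero weight exp(-inf) then
# multiplies that undefined output, and 0 * NaN is NaN, so the merged result
# can be poisoned. Values below are hand-picked to show the effect.
def _sketch_inf_lse_poisons_merge() -> torch.Tensor:
    lse = torch.tensor([0.0, -math.inf])            # second chunk saw no keys
    attn = torch.tensor([[1.0], [float("nan")]])    # its partial output is undefined
    w = torch.exp(lse - lse.max())                  # -> [1., 0.]
    merged = (w[:, None] * attn).sum(dim=0) / w.sum()
    return merged                                   # NaN, because 0 * NaN == NaN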
@disable_on_rocm
@sm80_or_better_only
@pytest.mark.parametrize(
"dtype,op",
[
(torch.bfloat16, fmha.triton_splitk.FwOp_S1),
(torch.bfloat16, fmha.triton_splitk.FwOp_S32),
],
ids=lambda o: f"{o.NAME}" if hasattr(o, "NAME") else str(o),
)
@pytest.mark.parametrize("gqa", [False, True], ids=lambda x: "gqa" if x else "")
def test_merge_attentions_sharedinput(
dtype: torch.dtype,
op: Type[AttentionFwOpBase],
gqa: bool,
):
"""
Compute decoding attention on chunks of K/V and merge them together.
Compare with computing attention on the whole K/V.
"""
MAX_T = 8192
N_H_L = 16
D_H = 128
G = 2
torch.manual_seed(1)
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None
shared_length = 20
full_lengths = [30, 35, 40]
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[1, 1, 1],
kv_padding=MAX_T,
kv_seqlen=full_lengths,
)
attn_bias1 = fmha.attn_bias.BlockDiagonalPaddedKeysMask.from_seqlens(
q_seqlen=[2, 1],
kv_padding=MAX_T,
kv_seqlen=[shared_length, 0],
)
attn_bias2 = fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens(
q_seqlen=[1, 1, 1],
kv_seqstarts=[shared_length, MAX_T + shared_length, 2 * MAX_T, 3 * MAX_T],
kv_seqlen=[
full_lengths[0] - shared_length,
full_lengths[1] - shared_length,
full_lengths[2],
],
)
q = torch.randn((1, 3, G, N_H_L, D_H), dtype=dtype, device="cuda")
k = torch.randn((3, MAX_T, G, 1 if gqa else N_H_L, D_H), dtype=dtype, device="cuda")
v = torch.randn_like(k)
k[1, :shared_length] = k[0, :shared_length]
v[1, :shared_length] = v[0, :shared_length]
k = k.flatten(end_dim=1)[None]
v = v.flatten(end_dim=1)[None]
k = k.expand((1, 3 * MAX_T, G, N_H_L, D_H))
v = v.expand((1, 3 * MAX_T, G, N_H_L, D_H))
attn_chunk1, lse_chunk1 = fmha.memory_efficient_attention_partial(
q,
k,
v,
attn_bias1,
op=op,
output_dtype=output_dtype,
)
assert attn_chunk1.shape == (1, 3, G, N_H_L, D_H)
assert lse_chunk1.shape == (1, G, N_H_L, 3)
if gqa:
attn_chunk1a, lse_chunk1a = fmha.memory_efficient_attention_partial(
q,
k.contiguous(),
v,
attn_bias1,
op=op,
output_dtype=output_dtype,
)
assert attn_chunk1a.shape == (1, 3, G, N_H_L, D_H)
assert lse_chunk1a.shape == (1, G, N_H_L, 3)
assert_allclose(
attn_chunk1a.nan_to_num(0, 0, 0), attn_chunk1.nan_to_num(0, 0, 0)
)
assert_allclose(lse_chunk1a.nan_to_num(0, 0, 0), lse_chunk1.nan_to_num(0, 0, 0))
attn_chunk2, lse_chunk2 = fmha.memory_efficient_attention_partial(
q,
k,
v,
attn_bias2,
op=op,
output_dtype=output_dtype,
)
assert attn_chunk2.shape == (1, 3, G, N_H_L, D_H)
assert lse_chunk2.shape == (1, G, N_H_L, 3)
# Merge attention from all chunks
attn_out, lse_out = fmha.merge_attentions(
[attn_chunk1, attn_chunk2], [lse_chunk1, lse_chunk2], output_dtype=dtype # type: ignore
)
assert lse_out is not None
# Compute attention on the full K/V
attn_full, lse_full = fmha.memory_efficient_attention_partial(
q,
k,
v,
attn_bias,
op=op,
output_dtype=output_dtype,
)
assert_allclose(
attn_out.to(attn_full.dtype), attn_full, rtol=1e-2, atol=2e-3, msg="out"
)
assert_allclose(
lse_out.to(lse_full.dtype), lse_full, rtol=1e-3, atol=1e-3, msg="lse"
)
@sm80_or_better_only
@pytest.mark.parametrize("bmghk", (False, True))
@pytest.mark.parametrize(
"stack_inputs", (False, True), ids=lambda x: "stack_inputs" if x else ""
)
@pytest.mark.parametrize(
"grad_var", ("lse", "attn", None)
) # Gradient with respect to attention, LSE, or neither
def test_merge_attentions_against_ref(
bmghk: bool, stack_inputs: bool, grad_var: Optional[str]
):
split_k = 16
B = 12
M = 137
G = 2 if bmghk else 1
N_H_L = 8
D_H = 128
dtype = torch.float32
attn_split = torch.randn([split_k, B, M, G, N_H_L, D_H], dtype=dtype, device="cuda")
lse_split = torch.randn([split_k, B, G, N_H_L, M], dtype=dtype, device="cuda")
if not bmghk:
attn_split = attn_split[:, :, :, 0]
lse_split = lse_split[:, :, 0]
if grad_var is not None:
attn_split.requires_grad_(True)
lse_split.requires_grad_(True)
attn_out_ref, lse_out_ref = _merge_attentions_ref(attn_split, lse_split)
if grad_var is not None:
if grad_var == "attn":
out_grad = torch.randn_like(attn_out_ref)
attn_out_ref.backward(out_grad)
else:
out_grad = torch.randn_like(lse_out_ref)
lse_out_ref.backward(out_grad)
attn_grad_ref, lse_grad_ref = attn_split.grad, lse_split.grad
attn_split = attn_split.detach().unbind(0) # type: ignore
lse_split = lse_split.detach().unbind(0) # type: ignore
for x in attn_split + lse_split:
x.requires_grad_(True)
x.retain_grad()
attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split)
torch.testing.assert_close(lse_out, lse_out_ref, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(attn_out, attn_out_ref, rtol=1e-4, atol=1e-4)
if grad_var is not None:
if grad_var == "attn":
attn_out.backward(out_grad)
else:
assert lse_out is not None
lse_out.backward(out_grad)
attn_grads = [x.grad for x in attn_split]
lse_grads = [x.grad for x in lse_split]
attn_grad_concat = torch.stack(attn_grads, dim=0)
lse_grad_concat = torch.stack(lse_grads, dim=0)
if grad_var == "lse":
            # LSE doesn't depend on attn_split, so when only the gradient with
            # respect to LSE is provided as input, the output gradient with
            # respect to attn_split is zero.
            # The reference implementation produces None instead of zeros in this case.
attn_grad_ref = torch.zeros_like(attn_grad_concat)
torch.testing.assert_close(lse_grad_concat, lse_grad_ref, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(
attn_grad_concat, attn_grad_ref, rtol=1e-4, atol=1e-4
)
def _merge_attentions_ref(attn_split, lse_split):
"""
attn_split: [split_k, B, M, (G,) H, Kq]
lse_split: [split_k, B, (G,) H, M]
"""
is_bmghk = len(attn_split.shape) == 6
if not is_bmghk:
attn_split = attn_split.unsqueeze(3)
lse_split = lse_split.unsqueeze(2)
lse_split = lse_split[..., None].moveaxis(4, 2) # [split_k, B, M, G, H, 1]
lse_max, _ = torch.max(lse_split, dim=0) # [B, M, G, H, 1]
sumexp_normalized = torch.exp(lse_split - lse_max) # [split_k, B, M, G, H, 1]
denominator = sumexp_normalized.sum(dim=0) # [B, M, G, H, 1]
numerator = (sumexp_normalized * attn_split).sum(dim=0) # [B, M, G, H, K]
attn_out = numerator / denominator # [B, M_ceil, G, H, Kq]
lse_out = lse_max + torch.log(denominator)
lse_out = lse_out.squeeze(4).permute(0, 2, 3, 1) # [B, G, H, M]
if not is_bmghk:
attn_out = attn_out.squeeze(2)
lse_out = lse_out.squeeze(1)
return attn_out, lse_out
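# Illustrative sketch (not part of the test suite): a tiny CPU-only check of
# the identity implemented by _merge_attentions_ref. Softmax attention over the
# full key set equals the LSE-weighted merge of attention computed over two key
# chunks. Single head, no scaling, made-up sizes; the helper name is ours.
def _sketch_merge_identity(M=3, Mkv=8, K=4, split=5) -> torch.Tensor:
    q = torch.randn(M, K)
    k = torch.randn(Mkv, K)
    v = torch.randn(Mkv, K)
    def attend_chunk(ks, vs):
        scores = q @ ks.T                            # [M, chunk_len]
        lse = torch.logsumexp(scores, dim=-1)        # [M]
        attn = torch.softmax(scores, dim=-1) @ vs    # [M, K]
        return attn, lse
    a1, l1 = attend_chunk(k[:split], v[:split])
    a2, l2 = attend_chunk(k[split:], v[split:])
    lse_max = torch.maximum(l1, l2)
    w1, w2 = torch.exp(l1 - lse_max), torch.exp(l2 - lse_max)
    merged = (w1[:, None] * a1 + w2[:, None] * a2) / (w1 + w2)[:, None]
    full, _ = attend_chunk(k, v)
    assert torch.allclose(merged, full, atol=1e-5)
    return merged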
@sm80_or_better_only
@skip_if_rocm # rocm doesn't support backward yet
@pytest.mark.parametrize(
"bias_t",
[None, fmha.attn_bias.LowerTriangularMask, fmha.attn_bias.BlockDiagonalMask],
)
@pytest.mark.parametrize("create_bias_inside_compiled", [False, True])
@pytest.mark.parametrize(
"op",
[None, (fmha.flash.FwOp, fmha.flash.BwOp), (fmha.cutlass.FwOp, fmha.flash.BwOp)],
)
def test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -> None:
torch.manual_seed(0)
torch._dynamo.reset_code_caches() # avoids hitting recompilation limit
B, M, H, K = 1, 256, 2, 64
q, k, v, bias = create_tensors(
op if op is None else op[0],
"cuda",
torch.float16,
bias_t,
B,
M,
M,
H,
K,
K,
fmt="BMHK",
)
grad = torch.randn_like(q)
if create_bias_inside_compiled:
bias = None
if bias_t not in [None, fmha.attn_bias.LowerTriangularMask]:
pytest.skip("Can't create this mask inside compile")
if bias is not None:
bias.to(q.device)
q.requires_grad_(True)
k.requires_grad_(True)
v.requires_grad_(True)
def fmha_fn(q, k, v, bias):
if create_bias_inside_compiled and bias_t is not None:
bias = bias_t()
return fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=op)
# Eager reference
out_ref = fmha_fn(q, k, v, bias)
out_ref.backward(grad)
dq_ref, dk_ref, dv_ref = q.grad, k.grad, v.grad
q.grad, k.grad, v.grad = None, None, None
# Compiled version
fmha_c = torch.compile(fmha_fn, fullgraph=True, dynamic=False)
out = fmha_c(q, k, v, bias)
out.backward(grad)
assert_allclose(
out,
out_ref,
"out",
atol=fmha.flash.FwOp.ERROR_ATOL[q.dtype],
rtol=fmha.flash.FwOp.ERROR_RTOL[q.dtype],
)
atol, rtol = (
fmha.flash.BwOp.ERROR_ATOL[q.dtype],
fmha.flash.BwOp.ERROR_RTOL[q.dtype],
)
assert_allclose(q.grad, dq_ref, "dq", atol=atol, rtol=rtol)
assert_allclose(k.grad, dk_ref, "dk", atol=atol, rtol=rtol)
assert_allclose(v.grad, dv_ref, "dv", atol=atol, rtol=rtol)
def test_bias_lower_triangular() -> None:
mask = fmha.attn_bias.LowerTriangularMask()
mask.detach()
def test_bias_lower_triangular_with_bias() -> None:
dense_bias = torch.randn([128, 128], dtype=torch.float16, requires_grad=True)
grad = torch.randn_like(dense_bias)
mask = fmha.attn_bias.LowerTriangularMask()
mask_biased = mask.add_bias(dense_bias)
mask_biased2 = mask_biased.detach()
mask_biased.backward(grad)
assert dense_bias.grad is not None
assert mask_biased2.grad is None
assert_allclose(dense_bias.grad, grad, "dense.grad")
# end of file