# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import math | |
import random | |
from functools import partial | |
from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar, Union | |
import pytest | |
import torch | |
import torch.nn.functional as F | |
from scipy.stats import binomtest | |
from torch.utils.checkpoint import checkpoint | |
import xformers.ops | |
from xformers.attn_bias_utils import create_attn_bias | |
from xformers.ops import fmha | |
from xformers.ops.fmha import ALL_BW_OPS, ALL_FW_OPS | |
from xformers.ops.fmha.common import AttentionFwOpBase, AttentionOpBase | |
from xformers.ops.fmha.dispatch import _dispatch_fw_priority_list | |
from .utils import ( | |
assert_allclose, | |
cuda_only, | |
disable_on_rocm, | |
disable_tf32, | |
pack_kv_cache, | |
ref_attention_bmhk_for_test, | |
ref_attention_for_test, | |
rocm_only, | |
) | |
compute_capability = (0, 0) | |
if torch.cuda.is_available(): | |
compute_capability = torch.cuda.get_device_capability("cuda") | |
sm70_or_better_only = pytest.mark.skipif( | |
compute_capability < (7, 0), reason="requires sm70+" | |
) | |
sm75_or_better_only = pytest.mark.skipif( | |
compute_capability < (7, 5), reason="requires sm75+" | |
) | |
sm80_or_better_only = pytest.mark.skipif( | |
compute_capability < (8, 0), reason="requires sm80+" | |
) | |
skip_if_rocm = pytest.mark.skipif( | |
torch.version.hip is not None, reason="not supported on ROCm" | |
) | |
skip_if_pt_cutlass = pytest.mark.skipif( | |
fmha.cutlass.USE_TORCH_CUTLASS, reason="using PT cutlass" | |
) | |
_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] | |
T = TypeVar( | |
"T", Type[fmha.common.AttentionFwOpBase], Type[fmha.common.AttentionBwOpBase] | |
) | |
logger = logging.getLogger("xformers") | |
def _filter_unsupported_ops(ops: Sequence[T]) -> List[T]: | |
return [ | |
op | |
for op in ops | |
if ( | |
"cpu" in op.SUPPORTED_DEVICES | |
or op.CUDA_MINIMUM_COMPUTE_CAPABILITY <= compute_capability | |
) | |
and op.is_available() | |
] | |
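# Keep only operators that are actually usable in this environment: the op must
# be available (compiled in) and must either support CPU or have a CUDA minimum
# compute capability no higher than the current GPU's (`compute_capability` is
# (0, 0) when CUDA is unavailable, so CUDA-only ops are filtered out on CPU).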
ALL_FW_OPS_NO_UNPADDED_LSE = _filter_unsupported_ops(ALL_FW_OPS) | |
ALL_FW_OPS = _filter_unsupported_ops( | |
ALL_FW_OPS | |
+ ( | |
[fmha.flash.FlashFwUnpaddedLSE] | |
if fmha.flash.FLASH_SUPPORTS_UNPADDED_LSE | |
else [] | |
) | |
) | |
ALL_BW_OPS = _filter_unsupported_ops(ALL_BW_OPS) | |
def sample_random_supported_fw( | |
inp: fmha.Inputs, seed: int | |
) -> Type[fmha.common.AttentionFwOpBase]: | |
r = random.Random(seed) | |
fw_ops = list(ALL_FW_OPS) | |
r.shuffle(fw_ops) | |
for op in fw_ops: | |
if op.supports(inp): | |
return op | |
raise NotImplementedError(f"Could not find a FW operator for: {inp}") | |
def generate_test_shapes_B_Mq_Mkv_H_K_Kv(op): | |
shapes = [] | |
for B in op._TEST_BATCH_SIZES: | |
for Mq in [32, 256]: | |
for Mkv in [32, 64, 256, 1024]: | |
for K in op._TEST_K: | |
shapes.append((B, Mq, Mkv, 1, K, K)) | |
Mq = 256 | |
Mkv = 128 | |
K = 32 | |
H = 1 | |
# Weird values of parameters | |
for M in [2, 3, 15, 31, 32, 34, 68, 72, 90, 132, 136]: | |
shapes.append((B, M, Mkv, H, K, K)) | |
shapes.append((B, Mq, M, H, K, K)) | |
for _K in [1, 2, 3, 31, 34, 36, 38, 40, 64, 80, 160, 256 + 2, 256 + 8, 512]: | |
if _K <= op.SUPPORTED_MAX_K: | |
shapes.append((B, Mq, Mkv, H, _K, _K)) | |
# Different value for K / Kv | |
if op.SUPPORTS_DIFFERENT_VALUE_EMBED: | |
for _K in [32, 36, 64, 256 + 8]: | |
shapes.append((B, Mq, Mkv, H, K, _K)) | |
shapes.append((B, Mq, Mkv, H, _K, K)) | |
# Exotic sizes | |
for K in op._TEST_K: | |
shapes.append((B, 16, 1024, H, K, K)) | |
shapes.append((B, 1024, 16, H, K, K)) | |
# Some number of heads | |
for H in [3, 5, 12]: | |
shapes.append((max(1, B // H), Mq, Mkv, H, K, K)) | |
    # Filter out unsupported shapes
shapes = [ | |
shape | |
for shape in shapes | |
if len( | |
op.shape_not_supported_reasons( | |
Mq=shape[1], Mkv=shape[2], K=shape[4], Kv=shape[5] | |
) | |
) | |
== 0 | |
] | |
# Add some random shapes | |
if op in [ | |
fmha.cutlass.FwOp, | |
fmha.cutlass.BwOp, | |
fmha.flash.BwOp, | |
fmha.ck.FwOp, | |
]: | |
K_CHOICES = [8 * i for i in range(1, 256 // 8)] | |
r = random.Random(0) | |
found_count = 0 | |
while found_count < 200: | |
B = r.randint(1, 400) | |
Mq = r.randint(1, 500) | |
Mkv = r.randint(1, 500) | |
H = r.randint(2, 11) | |
B = max(B // H, 1) | |
K = r.choice(K_CHOICES) | |
Kv = r.choice(K_CHOICES) | |
if not op.SUPPORTS_DIFFERENT_VALUE_EMBED: | |
Kv = K | |
if len(op.shape_not_supported_reasons(Mq, Mkv, K, Kv)): | |
continue | |
found_count += 1 | |
shapes.append((B, Mq, Mkv, H, K, Kv)) | |
return shapes | |
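# Each generated shape is a tuple (B, Mq, Mkv, H, K, Kv): batch size, query
# length, key/value length, number of heads, query/key head dim and value head
# dim - the same ordering used by `create_tensors` and the parametrized tests.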
def make_id(op, device, dtype, bias_type, *shape): | |
return ( | |
f"{op.NAME}-{device}-{str(dtype)}-{bias_type.__name__}" | |
f"-{'-'.join([str(s) for s in shape])}" | |
) | |
def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv( | |
ops_list: Sequence[Type[fmha.AttentionOpBase]], max_shapes_per_op: int = 65000 | |
): | |
r = random.Random(0) | |
combination = [] | |
ids = [] | |
for op in ops_list: | |
op_count = 0 | |
# Sort list of masks, so it's deterministic across runs | |
LIST_MASKS = list(sorted(op.SUPPORTED_ATTN_BIAS_TYPES, key=lambda x: str(x))) | |
for shape in generate_test_shapes_B_Mq_Mkv_H_K_Kv(op): | |
has_one = False | |
for device in _devices: | |
if device not in op.SUPPORTED_DEVICES: | |
continue | |
for dtype in op.SUPPORTED_DTYPES: | |
bias_type = r.choice(LIST_MASKS) | |
# Avoid using too much memory | |
if bias_type not in [ | |
type(None), | |
fmha.attn_bias.LowerTriangularMask, | |
]: | |
B, Mq, Mkv, H, K, Kv = shape | |
B = min(B, 12) | |
if bias_type in { | |
fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask, | |
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask, | |
}: | |
Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) + 2 | |
elif bias_type in { | |
fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask, | |
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, | |
fmha.attn_bias.BlockDiagonalPaddedKeysMask, | |
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, | |
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask, | |
}: | |
Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) | |
shape = (B, Mq, Mkv, H, K, Kv) | |
combination.append((op, device, dtype, bias_type, *shape)) | |
ids.append( | |
f"{op.NAME}-{device}-{str(dtype)}-{bias_type.__name__}" | |
f"-{'-'.join([str(s) for s in shape])}" | |
) | |
has_one = True | |
if has_one: | |
op_count += 1 | |
if op_count > max_shapes_per_op: | |
break | |
# Some specific shapes for which we want to run without any mask | |
bias_type = type(None) | |
for shape in ( | |
        # Some strides/dims don't fit in a uint16
(1, 128, 128, 300, 128, 128), | |
(13, 1, 67, 200, 8, 8), | |
(1, 1 + 2**16, 4, 1, 8, 8), | |
(1, 4, 1 + 2**16, 1, 8, 8), | |
        # TODO: Some strides don't fit in a uint32
# Crashes on Flash, Errors on Cutlass | |
# (1, 1, 64000, 300, 128, 128) | |
): | |
for device in _devices: | |
if device not in op.SUPPORTED_DEVICES: | |
continue | |
for dtype in op.SUPPORTED_DTYPES: | |
combination.append((op, device, dtype, bias_type, *shape)) | |
return { | |
"argvalues": combination, | |
"ids": [make_id(*c) for c in combination], | |
} | |
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = pytest.mark.parametrize( | |
"opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", | |
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS), | |
) | |
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_NO_UNPADDED_LSE = ( | |
pytest.mark.parametrize( | |
"opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", | |
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS_NO_UNPADDED_LSE), | |
) | |
) | |
parametrize_opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = pytest.mark.parametrize( | |
"opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", | |
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_FW_OPS, max_shapes_per_op=1), | |
) | |
parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = pytest.mark.parametrize( | |
"opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", | |
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS), | |
) | |
parametrize_opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv__xs = pytest.mark.parametrize( | |
"opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv", | |
**_generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(ALL_BW_OPS, max_shapes_per_op=1), | |
) | |
def _rand_partition(r: random.Random, total: int, n: int) -> List[int]: | |
    # Returns a list of n positive integers summing to `total`
idx = {0, total} | |
while len(idx) < n + 1: | |
idx.add(r.randint(1, total - 1)) | |
s = sorted(idx) | |
return [e - b for b, e in zip(s[:-1], s[1:])] | |
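# _rand_partition draws n - 1 distinct cut points in (0, total) and returns the
# lengths of the resulting segments. Illustrative (not an actual recorded
# output): _rand_partition(random.Random(0), total=10, n=3) could yield
# [4, 2, 4] - three positive integers summing to 10.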
def get_bias_grad(attn_bias, clear: bool = False) -> Optional[torch.Tensor]: | |
tensor_with_grad: Optional[torch.Tensor] = None | |
if isinstance(attn_bias, torch.Tensor): | |
tensor_with_grad = attn_bias | |
if tensor_with_grad is not None: | |
grad = tensor_with_grad.grad | |
if clear: | |
tensor_with_grad.grad = None | |
return grad | |
return None | |
def create_tensors( | |
op: Optional[Type[AttentionOpBase]], | |
device, | |
dtype, | |
attn_bias_type, | |
B, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
*, | |
attn_bias_requires_grad: bool = False, | |
fmt: str = "BMK", | |
g: int = 1, | |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]: | |
torch.manual_seed(B * q_len + kv_len * k + kv) | |
mask_is_bottom_right = attn_bias_type is not None and issubclass( | |
attn_bias_type, | |
( | |
fmha.attn_bias.LowerTriangularFromBottomRightMask, | |
fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask, | |
fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask, | |
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask, | |
fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask, | |
fmha.attn_bias.LocalAttentionFromBottomRightMask, | |
), | |
) | |
if mask_is_bottom_right and q_len > kv_len: | |
# Bottom-right attention and local-attention masks require q_len <= kv_len | |
kv_len = q_len | |
if attn_bias_type is not None and issubclass( | |
attn_bias_type, | |
( | |
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask, | |
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask, | |
), | |
): | |
page_size_choices = [256, 512] | |
if op is not None and issubclass(op, fmha.triton_splitk.FwOp): | |
# TODO: enable small pages for flash attention when that's implemented | |
page_size_choices.extend([64, 128]) | |
page_size = random.choice(page_size_choices) | |
kv_len_paged = (kv_len + page_size - 1) // page_size * page_size | |
else: | |
kv_len_paged = kv_len | |
page_size = None | |
scale = 3 | |
if fmt == "BMK": | |
query = torch.randn((B * h, q_len, k), device=device, dtype=dtype) | |
key = torch.randn((B * h, kv_len_paged, k), device=device, dtype=dtype) | |
value = torch.randn((B * h, kv_len_paged, kv), device=device, dtype=dtype) | |
elif fmt == "BMHK": | |
query = torch.randn((B, q_len, h, k), device=device, dtype=dtype) | |
key = torch.randn((B, kv_len_paged, h, k), device=device, dtype=dtype) | |
value = torch.randn((B, kv_len_paged, h, kv), device=device, dtype=dtype) | |
else: | |
assert fmt == "BMGHK" | |
query = torch.randn((B, q_len, g, h, k), device=device, dtype=dtype) | |
key = torch.randn((B, kv_len_paged, g, 1, k), device=device, dtype=dtype) | |
value = torch.randn((B, kv_len_paged, g, 1, kv), device=device, dtype=dtype) | |
for x in [query, key, value]: | |
x.mul_(scale) | |
if fmt == "BMGHK": | |
# Expand - after the in-place mul | |
key = key.expand((B, kv_len_paged, g, h, k)) | |
value = value.expand((B, kv_len_paged, g, h, k)) | |
if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(attn_bias_type): | |
attn_bias_type = None | |
attn_bias = None | |
if attn_bias_type is not None: | |
attn_bias = create_attn_bias( | |
attn_bias_type, | |
batch_size=B, | |
num_heads=h, | |
num_heads_groups=g, | |
q_len=q_len, | |
kv_len=kv_len, | |
dtype=dtype, | |
device=device, | |
requires_grad=attn_bias_requires_grad, | |
fmt=fmt, | |
op=op, | |
page_size=page_size, | |
) | |
if isinstance( | |
attn_bias, | |
( | |
fmha.attn_bias.BlockDiagonalMask, | |
fmha.attn_bias.BlockDiagonalGappyKeysMask, | |
fmha.attn_bias.BlockDiagonalPaddedKeysMask, | |
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask, | |
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask, | |
), | |
): | |
query, key, value = [ | |
x.reshape([1, -1, *x.shape[2:]]) for x in [query, key, value] | |
] | |
inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias) | |
if op is not None: | |
reasons = op.not_supported_reasons(inputs) | |
if reasons: | |
err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})" | |
# Ensure we free memory to avoid OOMs | |
del query, key, value, attn_bias, inputs | |
pytest.skip(err_msg) | |
return query, key, value, attn_bias | |
def bmhk2bmk(tensor) -> torch.Tensor: | |
return ( | |
tensor.permute((0, 2, 1, 3)) | |
.contiguous() | |
.view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]]) | |
) | |
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor: | |
return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute( | |
(0, 2, 1, 3) | |
) | |
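# Layout helpers: bmhk2bmk folds the head dimension into the batch dimension,
# (B, M, H, K) -> (B * H, M, K), and bmk2bmhk does the reverse given num_heads,
# (B * H, M, K) -> (B, M, H, K). For example, a (6, 32, 16) tensor with
# num_heads=3 becomes (2, 32, 3, 16).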
def nanify_oob_seqlen(x: torch.Tensor) -> torch.Tensor: | |
align_to = 256 | |
if x.shape[1] % align_to == 0: | |
return x | |
pad = [0, 0] * x.ndim | |
pad[-3] = align_to - (x.shape[1] % align_to) | |
x_pad = torch.nn.functional.pad(x, pad, value=math.nan) | |
return x_pad[:, : x.shape[1]] | |
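# nanify_oob_seqlen pads the sequence dimension (dim 1) with NaNs up to a
# multiple of 256 and then slices back to the original length: the logical
# content is unchanged, but the storage past the true sequence length is NaN,
# so any kernel that reads out-of-bounds rows will produce NaNs in its output
# (this is what test_forward checks below).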
def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs): | |
( | |
op, | |
device, | |
dtype, | |
bias_type, | |
batch_size, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv | |
if packed and issubclass( | |
bias_type, | |
( | |
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask, | |
fmha.attn_bias.PagedBlockDiagonalGappyKeysMask, | |
), | |
): | |
pytest.skip( | |
"packed doesn't make sense with paged attention, since q has different shape than k/v" | |
) | |
if packed and not (k == kv and q_len == kv_len): | |
pytest.skip( | |
f"packed incompatible with `k ({k}) != kv ({kv})` or `q_len ({q_len}) != kv_len ({kv_len})`" | |
) | |
if fmt == "BMK" and not fmha.common._is_bias_type_supported_in_BMK(bias_type): | |
pytest.skip("BMK incompatible with this bias") | |
query, key, value, attn_bias = create_tensors( | |
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
fmt="BMHK" if packed else fmt, | |
**kwargs, | |
) | |
if packed: | |
c = torch.stack([query, key, value], 2) | |
if fmt == "BMK": | |
# bm3hk -> 3bhmk -> 3Bmk | |
c = c.permute(2, 0, 3, 1, 4).view([3, -1, q_len, k]) | |
query, key, value = c[0], c[1], c[2] | |
# Re-create bias in the right format | |
attn_bias = create_attn_bias( | |
bias_type=bias_type, | |
batch_size=batch_size, | |
num_heads=h, | |
num_heads_groups=1, | |
q_len=q_len, | |
kv_len=kv_len, | |
device=device, | |
dtype=dtype, | |
requires_grad=False, | |
fmt=fmt, | |
op=op, | |
) | |
elif fmt == "BMHK": | |
# bm3hk -> 3 x bmhk | |
query, key, value = xformers.ops.unbind(c, 2) | |
else: | |
assert False, f"Unsupport fmt {fmt} with packing" | |
assert not query.is_contiguous() | |
out = xformers.ops.memory_efficient_attention_forward( | |
query, key, value, attn_bias, op=op | |
) | |
assert not out.isnan().any(), ("Output has NaNs", attn_bias) | |
out2 = xformers.ops.memory_efficient_attention_forward( | |
nanify_oob_seqlen(query), | |
nanify_oob_seqlen(key), | |
nanify_oob_seqlen(value), | |
attn_bias, | |
op=op, | |
) | |
assert not out2.isnan().any(), "Output has NaNs - most likely reading out-of-bounds" | |
assert torch.allclose(out, out2, atol=0.0, rtol=0.0), ( | |
"Non-deterministic behavior", | |
attn_bias, | |
) | |
ref = ref_attention_for_test(query, key, value, attn_bias) | |
assert out.shape == ref.shape, out.shape | |
assert_allclose( | |
out.float(), | |
ref, | |
atol=op.ERROR_ATOL[dtype], | |
rtol=op.ERROR_RTOL.get(dtype, 1e-5), | |
) | |
def test_key_query_all_ones(q_len, kv_len, batch_size, k_len): | |
device = "cuda" | |
scale = 3 | |
# composable kernel doesn't support fp32 | |
dtype = torch.float16 if torch.version.hip else torch.float32 | |
query = torch.ones((batch_size, q_len, k_len), device=device, dtype=dtype) | |
key = torch.ones((batch_size, kv_len, k_len), device=device, dtype=dtype) | |
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale | |
out = xformers.ops.memory_efficient_attention(query, key, value) | |
# this should be equivalent to the average over value | |
ref = value.mean(1, keepdim=True).expand_as(query) | |
assert_allclose(out, ref, atol=1e-5) | |
def _block_diag_reshape_lse( | |
lse: torch.Tensor, q_seqinfo: fmha.attn_bias._SeqLenInfo | |
) -> torch.Tensor: | |
"""LSE can be padded, let's remove the padding""" | |
parts = [] | |
for slice, (start, end) in zip(lse.unbind(0), q_seqinfo.intervals()): | |
parts.append(slice[:, : end - start]) | |
return torch.cat(parts, dim=1).unsqueeze(0) | |
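# Converts an LSE tensor from the padded layout (B, H, MAX_LEN) returned for
# block-diagonal biases into the unpadded layout (1, H, total_query_len) by
# keeping only the valid prefix of each block, so it can be compared against
# the reference logsumexp.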
def test_logsumexp(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv): | |
( | |
op, | |
device, | |
dtype, | |
bias_type, | |
batch_size, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv | |
if op is fmha.ck.FwOp: | |
pytest.skip("logsumexp is not yet supported by ck-tiled fmha!") | |
query, key, value, attn_bias = create_tensors( | |
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
fmt="BMHK", | |
) | |
_out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad( | |
query, | |
key, | |
value, | |
op=op, | |
attn_bias=attn_bias, | |
) | |
query = query.transpose(1, 2) | |
key = key.transpose(1, 2) | |
attn = (query.float() / k**0.5) @ key.float().transpose(-2, -1) | |
if attn_bias is not None: | |
if isinstance( | |
attn_bias, | |
(fmha.attn_bias.AttentionBias, fmha.attn_bias.AttentionBiasSubTensor), | |
): | |
bias_shape = (1, 1, query.shape[2], key.shape[2]) | |
tensor_bias = attn_bias.materialize( | |
bias_shape, | |
device=query.device, | |
dtype=torch.float32, | |
) | |
else: | |
assert type(attn_bias) is torch.Tensor | |
tensor_bias = attn_bias | |
attn = attn + tensor_bias.float() | |
ref_lse = attn.logsumexp(-1) | |
if isinstance( | |
attn_bias, | |
( | |
fmha.attn_bias.BlockDiagonalMask, | |
fmha.attn_bias.BlockDiagonalGappyKeysMask, | |
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask, | |
fmha.attn_bias.BlockDiagonalPaddedKeysMask, | |
), | |
) and issubclass(op, (fmha.flash.FwOp, fmha.cutlass.FwOp)): | |
# Sometimes LSE is returned in padded format, i.e. (B, H, MAX_LEN) instead of (H, TOTAL_LEN). | |
# Unpad to compare with the reference. | |
# This is the case for Flash Attention when UNPADDED_LSE=False and for CUTLASS. | |
if op.UNPADDED_LSE: | |
lse = lse.unsqueeze(0) | |
else: | |
lse = _block_diag_reshape_lse(lse, attn_bias.q_seqinfo) | |
if op is fmha.cutlass.FwOp: | |
            # CUTLASS kernel pads the last dimension of LSE to 32
lse = lse[:, :, : ref_lse.shape[2]] | |
assert_allclose(lse, ref_lse, atol=2e-4) | |
def test_logsumexp_mqa(op): | |
if not op.is_available(): | |
pytest.skip("not available") | |
if op.CUDA_MINIMUM_COMPUTE_CAPABILITY > compute_capability: | |
skip_reason = ( | |
f"requires device with capability >= {op.CUDA_MINIMUM_COMPUTE_CAPABILITY} " | |
f"but your GPU has capability {compute_capability} (too old)" | |
) | |
pytest.skip(skip_reason) | |
dtype = torch.float16 | |
s = 3 | |
query = torch.randn([1, 1, 32, 128], dtype=dtype, device="cuda") * s | |
key = (torch.randn([1, 16, 1, 128], dtype=dtype, device="cuda") * s).expand( | |
-1, -1, 32, -1 | |
) | |
value = (torch.randn([1, 16, 1, 128], dtype=dtype, device="cuda") * s).expand( | |
-1, -1, 32, -1 | |
) | |
assert key.stride(2) == 0 | |
_, lse = xformers.ops.memory_efficient_attention_forward_requires_grad( | |
query, | |
key, | |
value, | |
op=op, | |
) | |
query, key, value = [x[0].transpose(0, 1) for x in [query, key, value]] | |
attn = (query.float() / query.shape[-1] ** 0.5) @ key.float().transpose(-2, -1) | |
ref_lse = attn.logsumexp(-1) | |
assert_allclose(lse[0, :, 0], ref_lse[:, 0], atol=2e-4) | |
def test_backward( | |
opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
grad_out_contiguous, | |
fmt, | |
): | |
( | |
op_bw, | |
device, | |
dtype, | |
bias_type, | |
batch_size, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv | |
attn_bias_requires_grad = ( | |
random.Random(q_len + kv_len * batch_size).randint(0, 1) > 0 | |
) | |
query, key, value, attn_bias = create_tensors( | |
*opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
attn_bias_requires_grad=attn_bias_requires_grad, | |
fmt=fmt, | |
) | |
# To understand why we do this, check the comment on the | |
# `AttentionBwOpBase` class | |
scale = None | |
if op_bw.SUPPORTS_CUSTOM_SCALE and query.shape[-1] < 32: | |
scale = (1 / 32) ** 0.5 | |
op_fw = ( | |
sample_random_supported_fw( | |
fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias), | |
seed=q_len * kv + kv_len * k, | |
) | |
if op_bw != fmha.cutlass.BwOp | |
else fmha.cutlass.FwOp | |
) | |
if op_bw == fmha.ck.BwOp: | |
op_fw = fmha.ck.FwOp | |
if dtype == torch.bfloat16: | |
pytest.skip( | |
"CK Fmha backward for bfloat16 currently is not very accurate for some cases!" | |
) | |
if grad_out_contiguous is False: | |
pytest.skip("CK Fmha does not support contiguous layout for grad_out!") | |
if k % 2 != 0: | |
pytest.skip( | |
"CK Fmha currently requires the headdim size of query input be an even value!" | |
) | |
qkv = None | |
if ( | |
fmt == "BMHK" | |
and query.shape[3] == value.shape[3] | |
and query.shape[1] == value.shape[1] | |
): | |
qkv = torch.stack([query, key, value], 2) | |
qkv.requires_grad_(True) | |
# bm3hk -> 3 x bmhk | |
query, key, value = xformers.ops.unbind(qkv, 2) | |
assert not query.is_contiguous() | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
if not op_bw.supports(fmha.Inputs(query, key, value, attn_bias)): | |
pytest.skip("inputs not supported") | |
out = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias, scale=scale, op=(op_fw, op_bw) | |
) | |
grad_out = torch.randn_like(out) | |
if grad_out_contiguous is False: | |
grad_out = torch.tensor([1.0], dtype=query.dtype, device=device)[ | |
None, None, : | |
].expand_as(out) | |
out.backward(grad_out) | |
if qkv is None and op_bw == fmha.cutlass.BwOp: | |
assert query.stride() == query.grad.stride() | |
grads = [] | |
if qkv is None: | |
grads = [query.grad, key.grad, value.grad] | |
query.grad = None | |
key.grad = None | |
value.grad = None | |
else: | |
grads = [qkv.grad] | |
qkv.grad = None | |
if attn_bias_requires_grad: | |
attn_bias_grad = get_bias_grad(attn_bias, clear=True) | |
if attn_bias_grad is not None: | |
grads.append(attn_bias_grad) | |
ref = ref_attention_for_test(query, key, value, attn_bias, scale=scale) | |
ref.backward(grad_out) | |
assert_allclose( | |
out.float(), | |
ref.float(), | |
"fw pass", | |
atol=op_fw.ERROR_ATOL[dtype], | |
rtol=op_fw.ERROR_RTOL[dtype], | |
) | |
del out | |
del grad_out | |
del ref | |
atol = op_bw.ERROR_ATOL[dtype] | |
rtol = op_bw.ERROR_RTOL[dtype] | |
grads_ref = [] | |
grads_name = [] | |
if qkv is None: | |
assert isinstance(query.grad, torch.Tensor) | |
assert isinstance(key.grad, torch.Tensor) | |
assert isinstance(value.grad, torch.Tensor) | |
grads_ref = [query.grad, key.grad, value.grad] | |
grads_name = ["query", "key", "value"] | |
else: | |
assert isinstance(qkv.grad, torch.Tensor) | |
grads_ref = [qkv.grad] | |
grads_name = ["qkv"] | |
if attn_bias_requires_grad: | |
attn_bias_grad = get_bias_grad(attn_bias) | |
if attn_bias_grad is not None: | |
grads_ref.append(attn_bias.grad) | |
grads_name.append("bias") | |
del query | |
del key | |
del value | |
del qkv | |
assert len(grads_ref) == len( | |
grads | |
), "Wrong number of gradients (maybe bias grad didn't backprop?)" | |
for name, calc_grad, ref_grad in zip(grads_name, grads, grads_ref): | |
assert_allclose( | |
calc_grad, | |
ref_grad, | |
msg=f"{op_fw.NAME}+{op_bw.NAME}:{name}", | |
atol=atol, | |
rtol=rtol, | |
) | |
def _vec_binom_test(x, n, p): | |
""" | |
vectorized implementation of scipy.stats.binom_test | |
this makes our tests much faster | |
reference: https://github.com/scipy/scipy/blob/v1.8.0/scipy/stats/_morestats.py#L2609-L2702 | |
""" | |
import numpy as np | |
from scipy.stats import distributions | |
x = np.atleast_1d(x) | |
d = distributions.binom.pmf(x, n, p)[:, None] | |
rerr = 1 + 1e-7 | |
# x < p * n case | |
i = np.arange(np.ceil(p * n), n + 1) | |
y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1) | |
pval1 = distributions.binom.cdf(x, n, p) + distributions.binom.sf(n - y, n, p) | |
# other case | |
i = np.arange(np.floor(p * n) + 1) | |
y = np.sum(distributions.binom.pmf(i, n, p) <= d * rerr, axis=1) | |
pval2 = distributions.binom.cdf(y - 1, n, p) + distributions.binom.sf(x - 1, n, p) | |
pval = np.where(x < p * n, pval1, pval2) | |
pval = np.minimum(1.0, pval) | |
return pval | |
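# _vec_binom_test returns one two-sided binomial-test p-value per entry of x
# (number of successes out of n trials with success probability p), so an
# entire batch of dropout keep-counts can be tested in one vectorized call,
# e.g. _vec_binom_test(np.array([480, 520]), n=1000, p=0.5) -> array of two
# p-values, instead of looping over scipy.stats.binomtest.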
def _get_drop_mask(op, batch_size, q_len, kv_len, p, device): | |
if op == fmha.cutlass.FwOp: | |
mask = torch.empty((batch_size, 1, q_len, kv_len), device=device) | |
rand_uniform = torch.ops.xformers._cutlass_rand_uniform(p, mask) | |
mask = (rand_uniform > p).to(torch.float32) | |
mask = mask.reshape(batch_size, q_len, kv_len) | |
elif op == fmha.ck.FwOp: | |
mask = torch.empty((batch_size, 1, q_len, kv_len), device=device) | |
# rand_uniform is an int8_t tensor | |
rand_uniform = torch.ops.xformers._ck_rand_uniform(p, mask) | |
mask = (rand_uniform <= int((1.0 - p) * 255.0)).to(torch.float32) | |
mask = mask.reshape(batch_size, q_len, kv_len) | |
else: | |
mask = torch.empty((batch_size, q_len, kv_len), device=device) | |
mask = torch.ops.xformers._temp_dropout(mask, p) | |
return mask | |
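# The returned tensor is a float dropout "keep" mask of shape
# (batch_size, q_len, kv_len) - 1.0 where the attention weight is kept and 0.0
# where it is dropped (at least for the CUTLASS/CK paths above) - generated
# with the same RNG stream as the corresponding forward operator so that
# ref_attention_for_test can replay the exact same dropout pattern.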
def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed, attn_bias): | |
device = "cuda" | |
scale = 3 | |
dtype = torch.float | |
if torch.version.hip and op == fmha.ck.FwOp: | |
dtype = torch.float16 | |
query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale | |
key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale | |
value = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale | |
inputs_for_support_check = fmha.Inputs(query, key, value, attn_bias, p, None) | |
if not op.supports(inputs_for_support_check): | |
del query, key, value, attn_bias | |
pytest.skip(f"{op.NAME}: unsupported input") | |
torch.manual_seed(seed) | |
out = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias, p, op=(op, None) | |
) | |
torch.manual_seed(seed) | |
out2 = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias, p, op=(op, None) | |
) | |
assert_allclose(out, out2, "dropout reproducibility") | |
torch.manual_seed(seed) | |
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device) | |
ref = ref_attention_for_test(query, key, value, attn_bias, mask, p) | |
if dtype is torch.float: | |
assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}" | |
else: | |
assert_allclose(out.float(), ref, atol=2.8e-2), f"{(out - ref).abs().max()}" | |
num_trials = 1000 | |
p_val_tol = 1e-6 | |
keep_prob = 1 - p | |
masks = [] | |
for i in range(num_trials): | |
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device) | |
masks.append(mask.clone().cpu()) | |
masks = torch.stack(masks, dim=0) | |
p_value = binomtest(int(masks.sum()), masks.numel(), p=keep_prob).pvalue | |
assert p_value > p_val_tol, p_value | |
masks = masks.sum(0).flatten() | |
p_values = _vec_binom_test(masks, num_trials, p=keep_prob) | |
assert all(p_values > p_val_tol) | |
def _test_dropout_backward(q_len, kv_len, batch_size, k, p, op, dtype): | |
if dtype is torch.bfloat16 and compute_capability < (8, 0): | |
pytest.skip("bf16 requires Sm80") | |
if not op.is_available(): | |
pytest.skip() | |
scale = 3 | |
device = "cuda" | |
query = torch.randn((batch_size, q_len, k), device=device, dtype=dtype) * scale | |
key = torch.randn((batch_size, kv_len, k), device=device, dtype=dtype) * scale | |
value = torch.randn((batch_size, kv_len, k), device=device, dtype=dtype) * scale | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
grad_out = torch.ones_like(query) | |
assert op.supports(fmha.Inputs(query=query, key=key, value=value, p=p)) | |
seed = 42 | |
torch.manual_seed(seed) | |
out = xformers.ops.memory_efficient_attention(query, key, value, p=p, op=(op, None)) | |
out.backward(grad_out) | |
grad_q = query.grad | |
grad_k = key.grad | |
grad_v = value.grad | |
query.grad = None | |
key.grad = None | |
value.grad = None | |
torch.manual_seed(seed) | |
mask = _get_drop_mask(op, batch_size, q_len, kv_len, p, device) | |
ref = ref_attention_for_test(query, key, value, None, mask, p) | |
ref.backward(grad_out) | |
atol, rtol = ( | |
fmha.AttentionBwOpBase.ERROR_ATOL[dtype], | |
fmha.AttentionBwOpBase.ERROR_RTOL[dtype], | |
) | |
assert_allclose( | |
grad_v, | |
value.grad, | |
"grad_v", | |
atol=atol, | |
rtol=rtol, | |
) | |
# TODO: Investigate why precision is worse | |
if dtype in [torch.float16, torch.bfloat16]: | |
atol = atol * 2 + 0.15 | |
rtol = rtol * 2 | |
assert_allclose( | |
grad_q, | |
query.grad, | |
"grad_q", | |
atol=atol, | |
rtol=rtol, | |
) | |
assert_allclose( | |
grad_k, | |
key.grad, | |
"grad_k", | |
atol=atol, | |
rtol=rtol, | |
) | |
def test_dropout_backward_small_k(q_len, kv_len, batch_size, k, p): | |
_test_dropout_backward( | |
q_len, kv_len, batch_size, k, p, op=fmha.small_k.FwOp, dtype=torch.float32 | |
) | |
def test_dropout_backward_cutlass(dt, q_len, kv_len, batch_size, k, p): | |
_test_dropout_backward( | |
q_len, | |
kv_len, | |
batch_size, | |
k, | |
p, | |
op=fmha.cutlass.FwOp, | |
dtype={"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dt], | |
) | |
def test_dropout_backward_ck(dt, q_len, kv_len, batch_size, k, p): | |
_test_dropout_backward( | |
q_len, | |
kv_len, | |
batch_size, | |
k, | |
p, | |
op=fmha.ck.FwOp, | |
dtype={"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dt], | |
) | |
def test_memory_efficient_attention_full_block_masked(q_len, kv_len, batch_size, k_len): | |
device = "cuda" | |
op_fw = fmha.small_k.FwOp | |
op_bw = fmha.small_k.BwOp | |
scale = 3 | |
query = torch.randn((batch_size, q_len, k_len), device=device) * scale | |
key = torch.randn((batch_size, kv_len, k_len), device=device) * scale | |
value = torch.randn((batch_size, kv_len, k_len), device=device) * scale | |
# in this case, most of the blocks in a row get masked | |
attn_bias = torch.full((3, 32), float("-inf"), device=device) | |
attn_bias[:2, :4] = 0 | |
attn_bias = attn_bias.flatten()[None, None, :].expand(1, q_len, -1) | |
out = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias, op=(op_fw, op_bw) | |
) | |
ref = ref_attention_for_test(query, key, value, attn_bias) | |
assert_allclose( | |
out, ref, atol=op_fw.ERROR_ATOL[query.dtype], rtol=op_fw.ERROR_RTOL[query.dtype] | |
) | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
grad_out = torch.ones_like(query) | |
out = xformers.ops.memory_efficient_attention(query, key, value, attn_bias) | |
out.backward(grad_out) | |
grad_q = query.grad | |
grad_k = key.grad | |
grad_v = value.grad | |
query.grad = None | |
key.grad = None | |
value.grad = None | |
ref = ref_attention_for_test(query, key, value, attn_bias) | |
ref.backward(grad_out) | |
atol = op_bw.ERROR_ATOL[query.dtype] | |
rtol = op_bw.ERROR_RTOL[query.dtype] | |
assert_allclose(grad_q, query.grad, "grad_q", atol=atol, rtol=rtol) | |
assert_allclose(grad_k, key.grad, "grad_k", atol=atol, rtol=rtol) | |
assert_allclose(grad_v, value.grad, "grad_v", atol=atol, rtol=rtol) | |
def test_lowlevel_api_shapes(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt): | |
query, key, value, attn_bias = create_tensors( | |
*opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt=fmt | |
) | |
grad_out = torch.ones_like(query) | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
out, lse = xformers.ops.memory_efficient_attention_forward_requires_grad( | |
query, key, value, attn_bias | |
) | |
assert out.ndim == query.ndim | |
dq, dk, dv = xformers.ops.memory_efficient_attention_backward( | |
grad_out, out, lse, query, key, value, attn_bias | |
) | |
assert dq.shape == query.shape | |
assert dk.shape == key.shape | |
assert dv.shape == value.shape | |
def test_cuda_streams( | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
): | |
( | |
op, | |
device, | |
dtype, | |
bias_type, | |
batch_size, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv | |
if device != "cuda": | |
pytest.skip("Not CUDA") | |
bias_type = None | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = [ | |
op, | |
device, | |
dtype, | |
bias_type, | |
batch_size, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
] | |
s_hipri = torch.cuda.Stream(priority=-1) | |
s_lopri = torch.cuda.Stream(priority=0) | |
query, key, value, attn_bias = create_tensors( | |
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMHK" | |
) | |
torch.cuda.synchronize() | |
with torch.cuda.stream(s_lopri): | |
torch.cuda._sleep(100_000_000) # wait 100m cycles | |
query *= 2 | |
s_hipri.wait_stream(s_lopri) | |
with torch.cuda.stream(s_hipri): | |
# If the kernel is scheduled in the main stream | |
# `query * 2` has not been executed yet | |
out = xformers.ops.memory_efficient_attention(query, key, value, op=(op, None)) | |
# Test that `s_lopri` is still sleeping | |
# and that `query *= 2` has not been executed yet | |
query2_main_stream = query * 2 | |
torch.cuda.synchronize() | |
# TODO: Figure out why this is failing sometimes | |
# The sleep timer seems to be high enough already ... | |
# assert torch.allclose(query2_main_stream, query), "Need to increase sleep time" | |
del query2_main_stream | |
ref = ref_attention_for_test(query, key, value) | |
assert out.shape == ref.shape, out.shape | |
assert_allclose( | |
out.float(), | |
ref.float(), | |
atol=op.ERROR_ATOL[dtype], | |
rtol=op.ERROR_RTOL.get(dtype, 1e-5), | |
) | |
def test_custom_scale(opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv): | |
p = 0.0 | |
scale = 0.1 | |
( | |
op_bw, | |
device, | |
dtype, | |
_, | |
B, | |
q_len, | |
kv_len, | |
H, | |
k, | |
Kv, | |
) = opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv | |
torch.manual_seed(q_len + kv_len + k) | |
if device != "cuda": | |
pytest.skip("Not CUDA") | |
query, key, value, attn_bias = create_tensors( | |
*opBW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, fmt="BMK" | |
) | |
inputs = fmha.Inputs( | |
query=query, key=key, value=value, attn_bias=attn_bias, scale=scale | |
) | |
op_fw = sample_random_supported_fw(inputs, seed=q_len * k + kv_len * k) | |
grad_out = query.new_ones(B * H, q_len, Kv) | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
reasons = op_fw.not_supported_reasons(inputs) | |
if reasons: | |
pytest.skip(f"{op_fw.NAME}: unsupported ({'/'.join(reasons)})") | |
reasons = op_bw.not_supported_reasons(inputs) | |
if reasons: | |
pytest.skip(f"{op_bw.NAME}: unsupported ({'/'.join(reasons)})") | |
    # NOTE: we still need to scale the inputs to avoid blowing up
    # the pre-softmax values (numerical stability)
s = k**-0.5 | |
out = xformers.ops.memory_efficient_attention( | |
query * s, key, value, attn_bias, p, scale, op=(op_fw, op_bw) | |
) | |
out.backward(grad_out) | |
grad_q, grad_k, grad_v = query.grad, key.grad, value.grad | |
query.grad = key.grad = value.grad = None | |
ref = ref_attention_for_test(query * s, key, value, attn_bias, None, p, scale) | |
ref.backward(grad_out) | |
ref_grad_q, ref_grad_k, ref_grad_v = query.grad, key.grad, value.grad | |
query.grad = key.grad = value.grad = None | |
atol = op_fw.ERROR_ATOL[dtype] | |
rtol = op_fw.ERROR_RTOL[dtype] | |
assert_allclose(out.float(), ref.float(), "out", atol=atol, rtol=rtol) | |
atol = op_bw.ERROR_ATOL[dtype] | |
rtol = op_bw.ERROR_RTOL[dtype] | |
assert_allclose(grad_q, ref_grad_q, "grad_q", atol=atol, rtol=rtol) | |
assert_allclose(grad_k, ref_grad_k, "grad_k", atol=atol, rtol=rtol) | |
assert_allclose(grad_v, ref_grad_v, "grad_v", atol=atol, rtol=rtol) | |
def apply_attention(query, key, value, attn_bias, op_fw, proj): | |
x = xformers.ops.memory_efficient_attention( | |
query, key, value, attn_bias=attn_bias, op=(op_fw, None) | |
) | |
x = proj(x) | |
return x | |
def test_grad_checkpointing( | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
use_reentrant, | |
): | |
fmt = "BMHK" | |
( | |
op, | |
device, | |
dtype, | |
bias_type, | |
batch_size, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
) = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv | |
if op is fmha.triton_splitk.FwOp: | |
pytest.skip("Triton Flash Decoding doesn't support backward pass yet") | |
if op is fmha.ck.FwOp: | |
pytest.skip("ck-tiled FMHA doesn't supported backward pass yet") | |
bias_type = None | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv = ( | |
op, | |
device, | |
dtype, | |
bias_type, | |
batch_size, | |
q_len, | |
kv_len, | |
h, | |
k, | |
kv, | |
) | |
query, key, value, attn_bias = create_tensors( | |
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
fmt=fmt, | |
) | |
qkv = None | |
if ( | |
fmt == "BMHK" | |
and query.shape[3] == value.shape[3] | |
and query.shape[1] == value.shape[1] | |
): | |
qkv = torch.stack([query, key, value], 2) | |
qkv.requires_grad_(True) | |
# bm3hk -> 3 x bmhk | |
query, key, value = xformers.ops.unbind(qkv, 2) | |
assert not query.is_contiguous() | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
proj = torch.nn.Linear(kv, k, device=device, dtype=dtype) | |
x = query | |
for _ in range(5): | |
x = checkpoint( | |
apply_attention, | |
x, | |
key, | |
value, | |
attn_bias, | |
op, | |
proj, | |
use_reentrant=use_reentrant, | |
) | |
x.mean().backward() | |
ALL_FW_OPS_NO_SMALLK = [op for op in ALL_FW_OPS if op is not fmha.small_k.FwOp] | |
def test_unsupported_cpu(op: Type[fmha.AttentionFwOpBase]): | |
q = torch.empty([1, 1, 1, 32]) | |
with pytest.raises(ValueError): | |
fmha.memory_efficient_attention(q, q, q, op=(op, None)) | |
def test_unsupported_stride_lastdim(op: Type[fmha.AttentionFwOpBase]): | |
q = torch.empty([1, 1, 32, 4], device="cuda", dtype=torch.float16).permute( | |
0, 3, 1, 2 | |
) | |
try: | |
fmha.memory_efficient_attention(q, q, q, op=(op, None)) | |
except ValueError as e: | |
if "Only work on pre-MLIR triton for now" in str(e): | |
pytest.skip("Only work on pre-MLIR triton for now") | |
q = q.contiguous() | |
fmha.memory_efficient_attention(q, q, q, op=(op, None)) | |
def test_unsupported_stride_alignment(op: Type[fmha.AttentionFwOpBase]): | |
q = torch.empty([1, 2, 1, 33], device="cuda", dtype=torch.float16)[:, :, :, :32] | |
try: | |
fmha.memory_efficient_attention(q, q, q, op=(op, None)) | |
except ValueError as e: | |
if "Only work on pre-MLIR triton for now" in str(e): | |
pytest.skip("Only work on pre-MLIR triton for now") | |
q = q.contiguous() | |
fmha.memory_efficient_attention(q, q, q, op=(op, None)) | |
def test_unsupported_dropout_combine_flash_cutlass() -> None: | |
q = torch.empty( | |
[1, 4, 1, 16], device="cuda", dtype=torch.float16, requires_grad=True | |
) | |
with pytest.raises(ValueError): | |
out = fmha.memory_efficient_attention( | |
q, q, q, p=0.1, op=(fmha.cutlass.FwOp, fmha.flash.BwOp) | |
) | |
out.backward(out) | |
with pytest.raises(ValueError): | |
out = fmha.memory_efficient_attention( | |
q, q, q, p=0.1, op=(fmha.flash.FwOp, fmha.cutlass.BwOp) | |
) | |
out.backward(out) | |
def test_attn_bias_causal() -> None: | |
m = -math.inf | |
causal_mask = torch.tensor([[0, m], [0, 0], [0, 0]]) | |
tensor_bias = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) | |
attn_bias = fmha.attn_bias.LowerTriangularMask() | |
assert_allclose(attn_bias.materialize(causal_mask.shape), causal_mask, "causal") | |
attn_bias = attn_bias.add_bias(tensor_bias) | |
assert_allclose( | |
attn_bias.materialize(causal_mask.shape), | |
tensor_bias + causal_mask, | |
"causal+tensor_bias", | |
) | |
def test_attn_bias_torch_tensor() -> None: | |
tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]) | |
attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias) | |
m = -math.inf | |
causal_bias = torch.tensor([[0, m, m], [0, 0, m]]) | |
assert_allclose( | |
attn_bias.materialize((2, 3)), causal_bias + tensor_bias, "tensor_bias+causal" | |
) | |
def test_attn_bias_blockdiag() -> None: | |
queries = [ | |
torch.randn([1, 3, 1, 8]), | |
torch.randn([1, 2, 1, 8]), | |
torch.randn([1, 5, 1, 8]), | |
] | |
attn_bias, q = fmha.BlockDiagonalMask.from_tensor_list(queries) | |
# Verify mask | |
as_tensor = attn_bias.materialize((10, 10)) | |
assert int((as_tensor != -math.inf).sum().item()) == 3 * 3 + 2 * 2 + 5 * 5 | |
assert_allclose(as_tensor[0:3, 0:3], torch.zeros([3, 3]), "batch0") | |
assert_allclose(as_tensor[3:5, 3:5], torch.zeros([2, 2]), "batch1") | |
assert_allclose(as_tensor[5:, 5:], torch.zeros([5, 5]), "batch2") | |
# Verify we can split it back | |
queries2 = attn_bias.split(q) | |
assert len(queries) == len(queries2) | |
for q1, q2 in zip(queries, queries2): | |
assert_allclose(q1, q2) | |
def test_attn_bias_blockdiag_batched() -> None: | |
queries = [ | |
torch.randn([1, 3, 1, 8]), | |
torch.randn([3, 2, 1, 8]), | |
torch.randn([1, 5, 1, 8]), | |
] | |
attn_bias, q = fmha.BlockDiagonalMask.from_tensor_list(queries) | |
# Verify mask | |
as_tensor = attn_bias.materialize((14, 14)) | |
assert int((as_tensor != -math.inf).sum().item()) == 3 * 3 + 3 * 2 * 2 + 5 * 5 | |
assert_allclose(as_tensor[0:3, 0:3], torch.zeros([3, 3]), "batch0") | |
assert_allclose(as_tensor[3:5, 3:5], torch.zeros([2, 2]), "batch1.0") | |
assert_allclose(as_tensor[5:7, 5:7], torch.zeros([2, 2]), "batch1.1") | |
assert_allclose(as_tensor[7:9, 7:9], torch.zeros([2, 2]), "batch1.2") | |
assert_allclose(as_tensor[9:, 9:], torch.zeros([5, 5]), "batch2") | |
# Verify we can split it back | |
queries2 = attn_bias.split(q) | |
assert len(queries) == len(queries2) | |
for q1, q2 in zip(queries, queries2): | |
assert_allclose(q1, q2) | |
def test_attn_bias_blockdiag_crossattn_causal() -> None: | |
# Q / KV have different seqlen | |
list_q = [ | |
torch.randn([1, 3, 1, 8]), | |
torch.randn([2, 1, 1, 8]), | |
] | |
list_k = [ | |
torch.randn([1, 2, 1, 8]), | |
torch.randn([2, 3, 1, 8]), | |
] | |
attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv( | |
list_q, list_k | |
) | |
# Verify mask | |
as_tensor = attn_bias.materialize((q.shape[1], k.shape[1])) | |
assert int((as_tensor != -math.inf).sum().item()) == 3 * 2 + 2 * 3 * 1 | |
assert_allclose(as_tensor[0:3, 0:2], torch.zeros([3, 2]), "batch0") | |
assert_allclose(as_tensor[3:4, 2:5], torch.zeros([1, 3]), "batch1.0") | |
assert_allclose(as_tensor[4:, 5:], torch.zeros([1, 3]), "batch1.1") | |
# Also test causal version | |
as_tensor = attn_bias.make_causal().materialize((q.shape[1], k.shape[1])) | |
assert_allclose( | |
as_tensor[3:4, 2:5], | |
fmha.attn_bias.LowerTriangularMask().materialize((1, 3)), | |
"batch1.0[causal]", | |
) | |
# Verify we can split it back | |
list_q2 = attn_bias.split_queries(q) | |
assert len(list_q) == len(list_q2) | |
for q1, q2 in zip(list_q, list_q2): | |
assert_allclose(q1, q2) | |
with pytest.raises(ValueError): | |
attn_bias.split_queries(k) | |
list_k2 = attn_bias.split_kv(k) | |
assert len(list_k) == len(list_k2) | |
for k1, k2 in zip(list_k, list_k2): | |
assert_allclose(k1, k2) | |
def test_attn_bias_blockdiag_crossattn_causal_with_prefix_qk_cond() -> None: | |
list_q = [ | |
torch.randn([1, 3, 1, 8]), | |
] | |
list_k = [ | |
torch.randn([1, 2, 1, 8]), | |
] | |
attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv( | |
list_q, list_k | |
) | |
with pytest.raises(ValueError): | |
attn_bias.make_causal_from_bottomright() | |
def test_attn_bias_blockdiag_crossattn_causal_with_prefix() -> None: | |
# Q / KV have different seqlen | |
list_q = [ | |
torch.randn([1, 2, 1, 8]), | |
torch.randn([2, 2, 1, 8]), | |
] | |
list_k = [ | |
torch.randn([1, 2, 1, 8]), | |
torch.randn([2, 5, 1, 8]), | |
] | |
attn_bias, q, k, _ = fmha.attn_bias.BlockDiagonalMask.from_tensor_lists_qkv( | |
list_q, list_k | |
) | |
as_tensor = attn_bias.make_causal_from_bottomright().materialize( | |
(q.shape[1], k.shape[1]) | |
) | |
m = -math.inf | |
assert_allclose( | |
as_tensor[0:2, 0:2], | |
torch.tensor([[0, m], [0, 0]], dtype=torch.float32), | |
"batch1.1[causal_with_prefix]", | |
) | |
assert_allclose( | |
as_tensor[2:4, 2:7], | |
torch.tensor([[0, 0, 0, 0, m], [0, 0, 0, 0, 0]], dtype=torch.float32), | |
"batch2.1[causal_with_prefix]", | |
) | |
assert_allclose( | |
as_tensor[4:6, 7:12], | |
torch.tensor([[0, 0, 0, 0, m], [0, 0, 0, 0, 0]], dtype=torch.float32), | |
"batch2.2[causal_with_prefix]", | |
) | |
def test_attn_bias_padded() -> None: | |
bsize, n_heads, d, padding = 8, 3, 8, 32 | |
torch.manual_seed(0) | |
# Q / KV have different seqlen | |
k = torch.randn((bsize, padding, n_heads, d), device="cuda", dtype=torch.float16) | |
k_seqlen = [5, 8, 7, 1, 9, 3, 12, 32] | |
other = bsize - 1 | |
v = torch.randn((bsize, padding, n_heads, d), device="cuda", dtype=torch.float16) | |
n_q_first = 4 | |
q = [ | |
torch.randn((1, n_q_first, n_heads, d), device="cuda", dtype=torch.float16), | |
torch.randn((1, other, n_heads, d), device="cuda", dtype=torch.float16), | |
] | |
q_cat = torch.cat([x.view(1, -1, n_heads, d) for x in q], dim=1) | |
q_seqlen = [n_q_first] + [1] * other | |
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( | |
q_seqlen=q_seqlen, | |
kv_seqlen=k_seqlen, | |
kv_padding=padding, | |
) | |
v = v.view(1, -1, n_heads, d) | |
k = k.view(1, -1, n_heads, d) | |
scores = (q_cat.transpose(1, 2) @ k.transpose(1, 2).transpose(2, 3)).float() | |
assert not scores.isnan().any() | |
mask = torch.full_like(scores, -float("inf")) | |
for i, (slen, qlen) in enumerate(zip(k_seqlen, q_seqlen)): | |
kseq_start = i * padding | |
qstart = sum(q_seqlen[:i]) | |
mask[:, :, qstart : qstart + qlen, kseq_start : kseq_start + slen] = torch.triu( | |
mask[:, :, qstart : qstart + qlen, kseq_start : kseq_start + slen].float(), | |
diagonal=1 + slen - qlen, | |
).float() | |
scores += mask | |
assert not scores.isnan().any() | |
    # scores: (1, 3, 11, 256) with 11 = sum(q_seqlen) and 256 = bsize * padding
    scores = torch.nn.functional.softmax(scores, -1).half()
    # (1, 3, 11, 256) @ (1, 3, 256, 8) -> (1, 3, 11, 8)
    output = scores @ v.transpose(1, 2)
output = output.transpose(1, 2).contiguous() | |
fmha_output = fmha.memory_efficient_attention_forward( | |
q_cat, k, v, attn_bias, scale=1.0 | |
) | |
# assert torch.allclose(output, fmha_output) | |
assert_allclose( | |
output, | |
fmha_output, | |
atol=fmha.cutlass.FwOp.ERROR_ATOL[torch.float16], | |
rtol=fmha.cutlass.FwOp.ERROR_RTOL[torch.float16], | |
) | |
def test_attn_bias_to_copy() -> None: | |
def _test_to_copy(attn_bias: torch.Tensor) -> None: | |
assert attn_bias.device.type == "cpu", f"{attn_bias.device}" | |
attn_bias_cuda = attn_bias.cuda() | |
assert attn_bias_cuda.device.type == "cuda", f"{attn_bias_cuda.device}" | |
attn_bias_fp16 = attn_bias.to(torch.float16) | |
assert attn_bias_fp16.device.type == "cpu", f"{attn_bias_fp16.device}" | |
assert attn_bias_fp16.dtype == torch.float16, f"{attn_bias_fp16.dtype}" | |
attn_bias = fmha.attn_bias.LowerTriangularMask().to("cpu") | |
_test_to_copy(attn_bias) | |
tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]) | |
attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu") | |
_test_to_copy(attn_bias) | |
def _kv_heads_label(kv_heads: Optional[int]) -> str: | |
if kv_heads is None: | |
return "" | |
if kv_heads == 1: | |
return "mq" | |
return f"gqa{kv_heads}" | |
def test_decoder( | |
op, | |
n_heads: int, | |
kv_heads: Optional[int], | |
padding: int, | |
bsz: int, | |
dtype: str, | |
dequant: bool = False, | |
num_queries: int = 1, | |
d: int = 128, | |
) -> None: | |
# kv_heads = 1: multiquery | |
# kv_heads = None: neither MQA nor GQA | |
# kv_heads > 1: BMGHK | |
if dtype == "bf16" and compute_capability < (8, 0): | |
raise pytest.skip("BF16 is only supported on SM80+") | |
import triton | |
if dequant and triton.__version__[:4] < "3.0.": | |
raise pytest.skip("dequant needs triton updates") | |
dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype] | |
torch.manual_seed(1) | |
if kv_heads is not None and kv_heads > 1: | |
k_shape: Tuple[int, ...] = (1, bsz * padding, kv_heads, n_heads, d) | |
q_shape: Tuple[int, ...] = ( | |
1, | |
bsz * num_queries, | |
kv_heads, | |
n_heads, | |
d, | |
) | |
else: | |
k_shape = (1, bsz * padding, n_heads, d) | |
q_shape = (1, bsz * num_queries, n_heads, d) | |
# TODO: support 2 kv heads etc. | |
k = torch.randn(k_shape, dtype=dtype_, device="cuda") | |
k_seqlen = torch.randint(num_queries, padding + 1, (bsz,)).tolist() | |
v = torch.randn(k_shape, dtype=dtype_, device="cuda") | |
q = torch.randn(q_shape, dtype=dtype_, device="cuda") | |
if dequant: | |
k_shape = k_shape[:-1] + (d // 8 + op.NUM_GROUPS,) | |
k = torch.zeros(k_shape, dtype=torch.int32, device="cuda") | |
k.random_() | |
k[..., : op.NUM_GROUPS].view(torch.float16).fill_(1.0) | |
v = torch.zeros(k_shape, dtype=torch.int32, device="cuda") | |
v.random_() | |
v[..., : op.NUM_GROUPS].view(torch.float16).fill_(1.0) | |
if kv_heads is not None: | |
k = k[..., :1, :].expand(k_shape) | |
v = v[..., :1, :].expand(k_shape) | |
if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, k, v)): | |
pytest.skip("; ".join(skip_reasons)) | |
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( | |
q_seqlen=[num_queries] * bsz, | |
kv_seqlen=k_seqlen, | |
kv_padding=padding, | |
) | |
decoder_output = fmha.memory_efficient_attention_forward( | |
q, | |
k, | |
v, | |
attn_bias, | |
op=op, | |
) | |
def dequant_cache(x): | |
x = x[..., op.NUM_GROUPS :, None].expand(k_shape[:-1] + (d // 8, 8)) | |
x = x // (2 ** (4 * torch.arange(8, device="cuda"))) | |
x = (x % 16).flatten(start_dim=-2) | |
return x.to(dtype_) + 1.0 | |
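    # dequant_cache mirrors the packed layout built above when dequant=True:
    # the first op.NUM_GROUPS int32 lanes are viewed as fp16 (presumably
    # per-group scale/shift metadata) and set to 1.0, and each remaining int32
    # packs eight 4-bit values, which are unpacked and mapped to [1.0, 16.0].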
if dequant: | |
k = dequant_cache(k) | |
v = dequant_cache(v) | |
ref_output = ref_attention_for_test(q, k, v, attn_bias) | |
assert_allclose( | |
decoder_output.to(ref_output.dtype), | |
ref_output, | |
atol=op.ERROR_ATOL[dtype_] * 4, | |
rtol=op.ERROR_RTOL[dtype_], | |
) | |
def test_triton_splitk_decoder( | |
op, | |
dequant: bool, | |
kv_heads: Optional[int], | |
n_heads: int, | |
padding: int, | |
bsz: int, | |
dtype: str, | |
) -> None: | |
# We omit dequant with f16: it needs a very high tol | |
test_decoder( | |
op, | |
kv_heads=kv_heads, | |
n_heads=n_heads, | |
padding=padding, | |
bsz=bsz, | |
dtype=dtype, | |
dequant=dequant, | |
) | |
def test_ck_splitk_decoder( | |
op, | |
kv_heads: Optional[int], | |
n_heads: int, | |
padding: int, | |
bsz: int, | |
dtype: str, | |
d: int, | |
) -> None: | |
# no quantized impl compared to cuda | |
test_decoder( | |
op, | |
kv_heads=kv_heads, | |
n_heads=n_heads, | |
padding=padding, | |
bsz=bsz, | |
dtype=dtype, | |
d=d, | |
) | |
# n_heads=1 => it's ambiguous whether it counts as multiquery
def test_triton_splitk_decoder_manyqueries( | |
op, | |
multiquery: bool, | |
n_heads: int, | |
padding: int, | |
bsz: int, | |
dtype: str, | |
num_queries: int, | |
) -> None: | |
kv_heads = 1 if multiquery else None | |
test_decoder( | |
op, | |
kv_heads=kv_heads, | |
n_heads=n_heads, | |
padding=padding, | |
bsz=bsz, | |
dtype=dtype, | |
num_queries=num_queries, | |
dequant=False, | |
) | |
def test_attn_bias_from_seqlens() -> None: | |
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens([3, 5, 1]) | |
out = bias.split(torch.randn([1, 3 + 5 + 1, 16])) | |
assert len(out) == 3 | |
assert tuple(out[0].shape) == (1, 3, 16) | |
def test_attn_bias_blockdiag_doc() -> None: | |
"""IMPORTANT: | |
This is the example in the doc for `BlockDiagonalMask`. | |
If this example needs to be updated, please also update the doc | |
""" | |
import torch | |
from xformers.ops import fmha | |
if torch.version.hip: | |
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") | |
K = 16 | |
dtype = torch.float16 | |
device = "cuda" | |
list_x = [ | |
torch.randn([1, 3, 1, K], dtype=dtype, device=device), | |
torch.randn([1, 6, 1, K], dtype=dtype, device=device), | |
torch.randn([1, 2, 1, K], dtype=dtype, device=device), | |
] | |
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x) | |
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) # type: ignore | |
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2) | |
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias) | |
list_out = attn_bias.split(out) | |
assert tuple(list_out[0].shape) == (1, 3, 1, K) | |
class TestAttnBias:
    @staticmethod
    def create_tensors(
dtype, | |
B: int = 2, | |
Mq: int = 32, | |
Mkv: int = 32, | |
H: int = 3, | |
K: int = 16, | |
Kv: int = 16, | |
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
return ( | |
torch.randn([B, Mq, H, K], device="cuda", dtype=dtype) * 3, | |
torch.randn([B, Mkv, H, K], device="cuda", dtype=dtype) * 3, | |
torch.randn([B, Mkv, H, Kv], device="cuda", dtype=dtype) * 3, | |
torch.randn([B, H, Mq, Mkv], device="cuda", dtype=dtype) * 3, | |
) | |
    @staticmethod
    def pad_bias(bias: torch.Tensor) -> torch.Tensor:
align_to = 16 | |
if (bias.shape[-1] % align_to) == 0: | |
return bias | |
pad_count = align_to - (bias.shape[-1] % align_to) | |
return torch.nn.functional.pad(bias, [0, pad_count])[:, :, :, : bias.shape[-1]] | |
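    # pad_bias pads the last dimension of the additive bias up to a multiple of
    # 16 and slices back to the original width: the logical values are
    # unchanged, but the underlying row stride becomes a multiple of 16, which
    # appears to satisfy the kernel's alignment requirement (see
    # test_wrong_alignment).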
def test_f16_biasf32(self) -> None: | |
q, k, v, bias = self.create_tensors(torch.float16) | |
fmha.memory_efficient_attention(q, k, v, attn_bias=bias) | |
bias = bias.to(torch.float32) | |
with pytest.raises((ValueError, RuntimeError)): | |
fmha.memory_efficient_attention(q, k, v, attn_bias=bias) | |
def test_f32_biasf16(self) -> None: | |
q, k, v, bias = self.create_tensors(torch.float32) | |
fmha.memory_efficient_attention(q, k, v, attn_bias=bias) | |
bias = bias.to(torch.float16) | |
with pytest.raises((ValueError, RuntimeError)): | |
fmha.memory_efficient_attention(q, k, v, attn_bias=bias) | |
def test_wrong_alignment(self, dtype) -> None: | |
op = fmha.cutlass.FwOp if torch.version.cuda else fmha.ck.FwOp | |
if dtype not in op.SUPPORTED_DTYPES: | |
pytest.skip( | |
f"{dtype=} is not supported by {op.__module__}.{op.__qualname__}" | |
) | |
q, k, v, bias = self.create_tensors(dtype, Mq=7, Mkv=5) | |
try: | |
fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=(op, None)) | |
return | |
except (ValueError, RuntimeError): | |
pass | |
# This case is not supported, likely due to padding issues | |
# Let's make sure it works with padding | |
assert bias.ndim == 4, bias.shape | |
bias_padded = self.pad_bias(bias) | |
out = fmha.memory_efficient_attention( | |
q, k, v, attn_bias=bias_padded, op=(op, None) | |
).float() | |
ref_out = ref_attention_bmhk_for_test(q, k, v, bias) | |
assert_allclose( | |
out, ref_out, atol=op.ERROR_ATOL[dtype], rtol=op.ERROR_RTOL[dtype] | |
) | |
def test_permuted_attn_bias(self) -> None: | |
op = fmha.cutlass.FwOp | |
dtype = torch.float16 | |
q, k, v, bias = self.create_tensors(dtype, Mq=7, Mkv=7) | |
bias = bias.transpose(-1, -2) # now `stride(-1) != 1` | |
# Either it works, or it raises an exception | |
# but we should never get a CUDA error | |
try: | |
out = fmha.memory_efficient_attention( | |
q, k, v, attn_bias=bias, op=(op, None) | |
).float() | |
ref_out = ref_attention_bmhk_for_test(q, k, v, bias) | |
assert_allclose( | |
out, ref_out, atol=op.ERROR_ATOL[dtype], rtol=op.ERROR_RTOL[dtype] | |
) | |
except (ValueError, RuntimeError): | |
pass | |
SM_AND_SHMEM_KBYTES = [ | |
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability | |
(50, 64), | |
(60, 64), | |
(70, 96), | |
(75, 64), | |
(80, 163), | |
(86, 99), | |
(89, 99), | |
# (90, 227), | |
] | |
def test_has_kernel_for(sm_shmem: Tuple[int, int], dtype_str: str) -> None: | |
dtype = {"f32": torch.float, "f16": torch.half, "bf16": torch.bfloat16}[dtype_str] | |
sm, shmem_kbytes = sm_shmem | |
if sm < 80 and dtype_str == "bf16": | |
return | |
for k in [16, 32, 64, 128, 256]: | |
assert torch.ops.xformers._has_cutlassF_kernel_for( | |
dtype, sm, shmem_kbytes * 1024, k | |
), f"k={k}" | |
assert torch.ops.xformers._has_cutlassB_kernel_for( | |
dtype, sm, shmem_kbytes * 1024, k | |
), f"k={k}" | |
def test_window_size_materialize() -> None: | |
seqlens = [4, 6] | |
attn_bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens( | |
q_seqlen=seqlens, | |
kv_seqlen=seqlens, | |
).make_local_attention(2) | |
mask = attn_bias.materialize( | |
(1, 1, sum(seqlens), sum(seqlens)), | |
device="cpu", | |
dtype=torch.float32, | |
) | |
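# The expected mask below has two diagonal blocks coming from seqlens [4, 6];
# within each block, make_local_attention(2) lets query i attend only to key i
# and key i - 1, which is the banded pattern written out in `true_mask`.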
true_mask = torch.log( | |
torch.Tensor( | |
[ | |
[ | |
[ | |
[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
[0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], | |
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], | |
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], | |
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], | |
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], | |
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], | |
] | |
] | |
] | |
) | |
) | |
assert torch.all(mask == true_mask) | |
def test_forward_gqa(opFW_biasT, Mq: int): | |
opFW, biasT = opFW_biasT | |
if Mq < 512 and ( | |
issubclass(biasT, fmha.attn_bias.LowerTriangularMask) | |
or issubclass(biasT, fmha.attn_bias.BlockDiagonalCausalMask) | |
): | |
pytest.skip("undefined upper left") | |
B_Mq_Mkv_H_K_Kv = (3, Mq, 512, 16, 128, 128) | |
test_forward( | |
( | |
opFW, | |
"cuda", | |
torch.float16, | |
biasT, | |
*B_Mq_Mkv_H_K_Kv, | |
), | |
packed=False, | |
fmt="BMGHK", | |
g=2, | |
) | |
def test_backward_gqa(opBW): | |
H = 8 | |
B_Mq_Mkv_H_K_Kv = (3, 512, 512, H, 128, 128) | |
dtype = torch.float16 | |
query, key, value, attn_bias = create_tensors( | |
*(opBW, "cuda", dtype, type(None), *B_Mq_Mkv_H_K_Kv), | |
attn_bias_requires_grad=False, | |
fmt="BMHK", | |
) | |
op = (fmha.cutlass.FwOp, opBW) | |
key = key[:, :, :1].expand(-1, -1, H, -1) | |
value = value[:, :, :1].expand(-1, -1, H, -1) | |
key.requires_grad_(True) | |
out = fmha.memory_efficient_attention(query, key, value, attn_bias=attn_bias) | |
out_ref = ref_attention_bmhk_for_test(query, key, value, attn_bias=attn_bias) | |
assert_allclose( | |
out.float(), | |
out_ref.float(), | |
atol=op[0].ERROR_ATOL[dtype], | |
rtol=op[0].ERROR_RTOL[dtype], | |
) | |
out.backward(query) | |
dk = key.grad | |
key.grad = None | |
out_ref.backward(query) | |
assert_allclose( | |
dk.float(), | |
key.grad.float(), | |
atol=op[1].ERROR_ATOL[dtype], | |
rtol=op[1].ERROR_RTOL[dtype], | |
) | |
def test_forward_gqa_one_group(opFW): | |
dtype = torch.float16 | |
B, Mq, Mkv, H, K = 3, 13, 16, 5, 128 | |
q = torch.randn([B, Mq, 1, H, K], dtype=dtype, device="cuda") * 3 | |
k = torch.randn([B, Mkv, 1, H, K], dtype=dtype, device="cuda") * 3 | |
v = torch.randn([B, Mkv, 1, H, K], dtype=dtype, device="cuda") * 3 | |
supported = opFW.supports(fmha.Inputs(q, k, v)) | |
if not supported: | |
supported_bmhk = opFW.supports(fmha.Inputs(q[:, :, 0], k[:, :, 0], v[:, :, 0])) | |
assert supported == supported_bmhk | |
pytest.skip("not supported") | |
out = fmha.memory_efficient_attention_forward(q, k, v, op=opFW) | |
ref = ref_attention_for_test(q, k, v) | |
assert_allclose( | |
out.float(), | |
ref, | |
atol=opFW.ERROR_ATOL[dtype], | |
rtol=opFW.ERROR_RTOL.get(dtype, 1e-5), | |
) | |
def test_flash_gqa_wrong_strides() -> None: | |
op = (fmha.flash.FwOp, None) | |
device = "cuda" | |
B, Mq, Mkv, G, H, K = 3, 1, 512, 2, 8, 128 | |
q = torch.empty((B, Mq, G, H, K), dtype=torch.float16, device=device) | |
kv = torch.empty((B, Mkv, G, H, K), dtype=torch.float16, device=device) | |
fmha.memory_efficient_attention(q, kv, kv, op=op) | |
kv = torch.empty((B, Mkv, H, G, K), dtype=torch.float16, device=device).permute( | |
0, 1, 3, 2, 4 | |
) | |
with pytest.raises(ValueError): | |
fmha.memory_efficient_attention(q, kv, kv, op=op) | |
kv = torch.empty((B, Mkv, G, 1, K), dtype=torch.float16, device=device) | |
with pytest.raises(ValueError): | |
fmha.memory_efficient_attention(q, kv, kv, op=op) | |
kv = kv.expand(-1, -1, -1, H, K) | |
fmha.memory_efficient_attention(q, kv, kv, op=op) | |
kv = torch.empty((B, Mkv, G, H, 2 * K), dtype=torch.float16, device=device)[ | |
:, :, :, :, :K | |
] | |
fmha.memory_efficient_attention(q, kv, kv, op=op) | |
def _dispatches_to_splitK(q, kv): | |
return ( | |
_dispatch_fw_priority_list(fmha.Inputs(q, kv, kv), False)[0] | |
is fmha.triton_splitk.FwOp | |
) | |
def _dispatches_to_flash_decoding(q, kv): | |
return ( | |
_dispatch_fw_priority_list(fmha.Inputs(q, kv, kv), False)[0] is fmha.flash.FwOp | |
) | |
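# Note on the two helpers above: _dispatch_fw_priority_list(inp, False) returns the
# ordered list of candidate forward ops for these inputs (the False flag means no
# gradient is needed), and element [0] is the first candidate the dispatcher would
# try when memory_efficient_attention is called without an explicit `op=`.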
def test_dispatch_decoding_bmhk() -> None: | |
assert not _dispatches_to_splitK( | |
torch.empty([1, 8, 1, 128]), torch.empty([1, 2048, 1, 128]) | |
), "Should not use SplitK with 1 head (no tensorcores)" | |
assert _dispatches_to_flash_decoding( | |
torch.empty([1, 8, 32, 128]), | |
torch.empty([1, 2048, 1, 128]).expand(-1, -1, 32, -1), | |
), "Should use Flash-Decoding with BMHK MQA" | |
assert not _dispatches_to_splitK( | |
torch.empty([1, 8, 32, 128]), | |
torch.empty([1, 2048, 32, 128]), | |
), "Should not use SplitK when no TensorCores" | |
assert not _dispatches_to_splitK( | |
torch.empty([1, 128, 32, 128]), | |
torch.empty([1, 2048, 1, 128]).expand(-1, -1, 32, -1), | |
), "Should not use SplitK if q seqlen is long" | |
assert not _dispatches_to_splitK( | |
torch.empty([128, 8, 32, 128]), | |
torch.empty([128, 2048, 1, 128]).expand(-1, -1, 32, -1), | |
), "Should not use SplitK if B is big" | |
def test_dispatch_decoding_bmghk() -> None: | |
assert not _dispatches_to_splitK( | |
torch.empty([1, 8, 1, 1, 128]), torch.empty([1, 2048, 1, 1, 128]) | |
), "Should not use SplitK with 1 head (no tensorcores)" | |
assert _dispatches_to_flash_decoding( | |
torch.empty([1, 8, 1, 32, 128]), | |
torch.empty([1, 2048, 1, 1, 128]).expand(-1, -1, -1, 32, -1), | |
), "Should use Flash-Decoding with MQA" | |
assert _dispatches_to_flash_decoding( | |
torch.empty([1, 8, 4, 32, 128]), | |
torch.empty([1, 2048, 4, 1, 128]).expand(-1, -1, -1, 32, -1), | |
), "Should use Flash-Decoding with GQA" | |
assert not _dispatches_to_splitK( | |
torch.empty([1, 8, 1, 32, 128]), | |
torch.empty([1, 2048, 1, 32, 128]), | |
), "Should not use SplitK when no TensorCores" | |
assert not _dispatches_to_splitK( | |
torch.empty([1, 128, 1, 32, 128]), | |
torch.empty([1, 2048, 1, 1, 128]).expand(-1, -1, -1, 32, -1), | |
), "Should not use SplitK if q seqlen is long" | |
assert not _dispatches_to_splitK( | |
torch.empty([128, 8, 1, 32, 128]), | |
torch.empty([128, 2048, 1, 1, 128]).expand(-1, -1, -1, 32, -1), | |
), "Should not use SplitK if B is big" | |
shapes_triton_splitk = [ | |
(1, 8, 2**16, 1, 128, 128), | |
(1, 4, 2**16, 1, 128, 128), | |
(1, 16, 2**16, 1, 128, 128), | |
(1, 16, 2**16, 1, 32, 32), | |
(1, 8, 1025, 1, 128, 128), | |
(2, 8, 4096, 1, 128, 128), | |
(10, 8, 2**16, 1, 128, 128), | |
(10, 15, 2**16, 1, 128, 128), | |
(1, 3, 2**16, 1, 128, 128), | |
(1, 3, 2**16 - 10, 1, 128, 128), | |
(2, 3, 73, 1, 128, 128), | |
(2, 7, 7328, 1, 128, 128), | |
(2, 7, 7328, 1, 120, 120), | |
(2, 7, 63, 1, 120, 120), | |
] | |
op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv_splitk = [ | |
(fmha.triton_splitk.FwOp, "cuda", torch.float16, type(None), *s) | |
for s in shapes_triton_splitk | |
] + [ | |
(fmha.triton_splitk.FwOp, "cuda", torch.bfloat16, type(None), *s) | |
for s in shapes_triton_splitk | |
] | |
def test_forward_splitk( | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
packed=False, | |
fmt="BMHK", | |
): | |
test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed=packed, fmt=fmt) | |
def test_mqa_decoding(op: Type[fmha.AttentionFwOpBase], dtype, B_Mkv_H_K): | |
B, Mkv, H, K = B_Mkv_H_K | |
q = torch.randn([B, 1, H, K], dtype=dtype, device="cuda") * 3 | |
k = torch.randn([B, Mkv, 1, K], dtype=dtype, device="cuda") * 3 | |
v = torch.randn([B, Mkv, 1, K], dtype=dtype, device="cuda") * 3 | |
k = k.expand(-1, -1, H, -1) | |
v = v.expand(-1, -1, H, -1) | |
if skip_reasons := op.not_supported_reasons(fmha.Inputs(q, k, v)): | |
pytest.skip("; ".join(skip_reasons)) | |
out = fmha.memory_efficient_attention_forward(q, k, v, op=op) | |
ref = ref_attention_for_test(q, k, v) | |
assert_allclose( | |
out.float(), | |
ref, | |
atol=op.ERROR_ATOL[dtype], | |
rtol=op.ERROR_RTOL.get(dtype, 1e-5), | |
) | |
def test_empty_tensors_empty_query( | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
): | |
query, key, value, attn_bias = create_tensors( | |
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
fmt="BMHK", | |
) | |
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0] | |
if torch.version.hip: | |
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") | |
query = query[:, :0] | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None)) | |
assert out.shape[1] == 0 | |
out.backward(out) | |
# dK/dV should be all zeros | |
assert_allclose(key.grad, torch.zeros_like(key.grad), "key.grad") | |
assert_allclose(value.grad, torch.zeros_like(value.grad), "value.grad") | |
def test_empty_tensors_empty_kv( | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
): | |
query, key, value, attn_bias = create_tensors( | |
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
fmt="BMHK", | |
) | |
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0] | |
if opFW == fmha.triton_splitk.FwOp: | |
pytest.skip("triton_splitk doesn't support empty kv") | |
if torch.version.hip: | |
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") | |
key = key[:, :0] | |
value = value[:, :0] | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None)) | |
assert_allclose(out, torch.zeros_like(out), "out") | |
out.backward(out) | |
# dQ should be all zeros | |
assert_allclose(query.grad, torch.zeros_like(query.grad), "query.grad") | |
def test_empty_tensors_empty_b( | |
opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
): | |
query, key, value, attn_bias = create_tensors( | |
*opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, | |
fmt="BMHK", | |
) | |
opFW = opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv[0] | |
if torch.version.hip: | |
pytest.skip("backward pass/gradience is not yet supported by ck-tiled fmha!") | |
query, key, value = query[:0], key[:0], value[:0] | |
query.requires_grad_(True) | |
key.requires_grad_(True) | |
value.requires_grad_(True) | |
out = xformers.ops.memory_efficient_attention(query, key, value, op=(opFW, None)) | |
out.backward(out) | |
def test_local_attn_bias() -> None: | |
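# window_left=1 / window_right=2 means query i may attend to keys j with
# i - 1 <= j <= i + 2, aligned from the bottom-right corner (which for a square
# 4x4 mask is simply the main diagonal); `expected` below spells out this pattern.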
mask = ( | |
fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2) | |
.materialize(shape=(4, 4)) | |
.exp() | |
) | |
expected = torch.tensor( | |
[[1, 1, 1, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.float32 | |
) | |
assert (mask == expected).all().item() | |
def test_cutlassB_iter_order( | |
dtype, | |
cc: int, | |
maxK: int, | |
num_queries: int, | |
num_keys: int, | |
custom_mask_type, | |
window_size, | |
) -> None: | |
""" | |
This tests some internals of the cutlassB kernel | |
We test the iteration across blocks of [queries, keys] to ensure | |
that we correctly: | |
* Iterate over all the blocks that should be iterated | |
* Do *not* iterate over blocks that are completely masked out | |
* Correctly compute the number of parallel blocks that will compute | |
the same block of dQ | |
... and we test this across various causal mask + local attention combinations | |
""" | |
if ( | |
window_size > 0 | |
and custom_mask_type == fmha.cutlass._CustomMaskType.NoCustomMask | |
): | |
pytest.skip("LocalAttention is only supported for causal") | |
get_iteration_data = partial( | |
torch.ops.xformers._cutlassB_iteration_data, | |
dtype=dtype, | |
cc=cc, | |
maxK=maxK, | |
num_queries=num_queries, | |
num_keys=num_keys, | |
custom_mask_type=custom_mask_type, | |
window_size=window_size, | |
) | |
bias = torch.zeros([num_queries, num_keys], dtype=torch.float32) | |
if custom_mask_type != fmha.cutlass._CustomMaskType.NoCustomMask: | |
bias = fmha.attn_bias._materialize_causal_mask( | |
(num_queries, num_keys), | |
dtype=torch.float32, | |
device="cpu", | |
window_size=None if window_size == 0 else window_size, | |
from_bottomright=( | |
custom_mask_type == fmha.cutlass._CustomMaskType.CausalFromBottomRight | |
), | |
) | |
block_queries, block_keys = get_iteration_data()[:2] | |
mask_pooled = ( | |
F.max_pool2d(bias.unsqueeze(0), (block_queries, block_keys), ceil_mode=True) | |
== 0 | |
).int()[0] | |
attn_computed = torch.zeros_like(mask_pooled) | |
for key_start in range(0, num_keys, block_keys): | |
it = 0 | |
new_key_start = key_start | |
new_query_start = get_iteration_data(key_start=key_start)[2] | |
try: | |
expected_first_query = ( | |
mask_pooled[:, key_start // block_keys].tolist().index(1) | |
* block_queries | |
) | |
assert ( | |
new_query_start == expected_first_query | |
), f"Wrong first query for K={key_start}: {new_query_start} (expected {expected_first_query})" | |
except ValueError: # Nothing to compute in this column | |
pass | |
while new_key_start == key_start and new_query_start < num_queries: | |
query_start = new_query_start | |
attn_computed[query_start // block_queries, key_start // block_keys] += 1 | |
# print(f"Compute [{query_start}, {key_start}]") | |
# Is there something to compute here? | |
assert mask_pooled[ | |
query_start // block_queries, key_start // block_keys | |
].item(), "Computing a block that is not needed!" | |
new_query_start, new_key_start = get_iteration_data( | |
key_start=key_start, query_start=query_start | |
)[3:5] | |
it += 1 | |
assert it < num_queries, "" | |
assert (attn_computed == mask_pooled)[ | |
:, key_start // block_keys | |
].all(), "some blocks were not computed!" | |
# Now check that the number returned by `getNumParallelBlocksForQuery` is correct | |
for query_start in range(0, num_queries, block_queries): | |
num_parallel_blocks = get_iteration_data( | |
query_start=query_start, num_splits_key=num_keys | |
)[5] | |
num_actual = mask_pooled[query_start // block_queries].sum().item() | |
assert num_parallel_blocks == num_actual | |
def test_paged_attention( | |
B, | |
MAX_T: int, | |
num_quant_groups: int, | |
page_size: int, | |
op: Type[AttentionFwOpBase], | |
gappy: bool, | |
): | |
paged_attention_run_inner( | |
B, MAX_T, num_quant_groups, page_size, op, bench=False, gappy=gappy | |
) | |
def test_paged_attention_flash(B, MAX_T: int, page_size: int): | |
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged | |
op = fmha.flash.FwOp | |
if ( | |
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask | |
not in op.SUPPORTED_ATTN_BIAS_TYPES | |
): | |
pytest.skip("Not supported bias") | |
num_quant_groups = 0 | |
paged_attention_run_inner(B, MAX_T, num_quant_groups, page_size, op, bench=False) | |
def paged_attention_run_inner( | |
B: int, | |
MAX_T: int, | |
num_quant_groups: int, | |
page_size: int, | |
op: Type[AttentionFwOpBase], | |
bench: bool, | |
gappy: bool = False, | |
) -> None: | |
import triton | |
torch.manual_seed(10) | |
TEST_WARMUP_MS = 500 | |
TEST_RUN_MS = 5000 | |
N_H_L = 8 | |
N_KVH_L = 1 | |
D_H = 128 | |
D_H_KV = D_H // 8 + num_quant_groups if num_quant_groups else D_H | |
kv_seqlens = torch.randint(low=1, high=MAX_T + 1, size=(B,)).tolist() | |
# Paged attention requires k.shape[1] and v.shape[1] to be divisible by page_size, so pad | |
padded_per_row_len = ((MAX_T + page_size - 1) // page_size) * page_size | |
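# e.g. with page_size=256, MAX_T=8192 is already a multiple, while a hypothetical
# MAX_T=1000 would be rounded up to padded_per_row_len=1024.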
if gappy: | |
make_paged_kwargs = { | |
"paged_type": fmha.attn_bias.PagedBlockDiagonalGappyKeysMask, | |
"notional_padding": MAX_T, | |
} | |
attn_bias = fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens( | |
q_seqlen=[1] * B, | |
kv_seqstarts=list(range(0, MAX_T * (B + 1), MAX_T)), | |
kv_seqlen=kv_seqlens, | |
) | |
else: | |
make_paged_kwargs = { | |
"paged_type": fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, | |
} | |
block_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask | |
attn_bias = block_type.from_seqlens( # type: ignore | |
q_seqlen=[1] * B, | |
kv_padding=MAX_T, | |
kv_seqlen=kv_seqlens, | |
) | |
q = torch.randn((B, 1, N_H_L, D_H), dtype=torch.bfloat16, device="cuda") | |
if num_quant_groups: | |
if triton.__version__[:4] < "3.0.": | |
raise pytest.skip("dequant needs triton updates") | |
# Using high=64 below, because with 256 both the paged and non-paged paths | |
# produce NaNs - probably some quantization coefficients become NaNs | |
# after the bitwise cast. | |
cache_k = torch.randint( | |
0, 64, (B, MAX_T, N_KVH_L, D_H_KV * 4), dtype=torch.uint8, device="cuda" | |
) | |
cache_k = cache_k.view(dtype=torch.int32) | |
cache_v = torch.randint( | |
0, 64, (B, MAX_T, N_KVH_L, D_H_KV * 4), dtype=torch.uint8, device="cuda" | |
) | |
cache_v = cache_v.view(dtype=torch.int32) | |
op = type( | |
f"{op.__name__}_{num_quant_groups}", | |
(op,), | |
{"NUM_GROUPS": num_quant_groups}, | |
) | |
else: | |
cache_k = torch.randn( | |
(B, MAX_T, N_KVH_L, D_H), dtype=torch.bfloat16, device="cuda" | |
) | |
cache_v = torch.randn_like(cache_k) | |
axq = q.view(1, B * 1, N_H_L, D_H) | |
axk = cache_k.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand( | |
1, B * MAX_T, N_H_L, D_H_KV | |
) | |
axv = cache_v.view(1, B * MAX_T, N_KVH_L, D_H_KV).expand( | |
1, B * MAX_T, N_H_L, D_H_KV | |
) | |
k_cache_size_usual = axk.numel() | |
# First, create "wasteful" K/V cache, where every block in logical cache | |
# has a physical representation, even if there's nothing stored there | |
block_tables = torch.arange( | |
B * padded_per_row_len // page_size, device="cuda", dtype=torch.int32 | |
).reshape(B, -1) | |
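# e.g. B=2, padded_per_row_len=512, page_size=256 yields the identity mapping
# block_tables = [[0, 1], [2, 3]]: logical block j of sequence b lives in physical
# page b * (padded_per_row_len // page_size) + j.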
shape_padded = (B, padded_per_row_len, N_KVH_L, D_H_KV) | |
axk_padded = torch.empty(shape_padded, device=axk.device, dtype=axk.dtype) | |
axv_padded = torch.empty(shape_padded, device=axv.device, dtype=axv.dtype) | |
axk_padded[:, :MAX_T] = axk.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :] | |
axv_padded[:, :MAX_T] = axv.view(B, -1, N_H_L, D_H_KV)[:, :, :1, :] | |
axk_padded = axk_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV) | |
axv_padded = axv_padded.view(1, B * padded_per_row_len, N_KVH_L, D_H_KV) | |
axk_padded = axk_padded.expand(-1, -1, N_H_L, -1) | |
axv_padded = axv_padded.expand(-1, -1, N_H_L, -1) | |
attn_bias_paged = attn_bias.make_paged( | |
block_tables=block_tables, page_size=page_size, **make_paged_kwargs # type: ignore | |
) | |
y_usual = fmha.memory_efficient_attention_forward( | |
axq, | |
axk, | |
axv, | |
attn_bias, | |
op=op, | |
) | |
if bench: | |
g = torch.cuda.CUDAGraph() | |
with torch.cuda.graph(g): | |
y_usual = fmha.memory_efficient_attention_forward( | |
axq, | |
axk, | |
axv, | |
attn_bias, | |
op=op, | |
) | |
t_ms = triton.testing.do_bench( | |
lambda g=g: g.replay(), | |
warmup=TEST_WARMUP_MS, | |
rep=TEST_RUN_MS, | |
) | |
logger.info(f"Non-paged attention took {t_ms * 1e3:.2f}us") | |
y_wasteful = fmha.memory_efficient_attention_forward( | |
axq, | |
axk_padded, | |
axv_padded, | |
attn_bias_paged, | |
op=op, | |
) | |
if bench: | |
g = torch.cuda.CUDAGraph() | |
with torch.cuda.graph(g): | |
y_wasteful = fmha.memory_efficient_attention_forward( | |
axq, | |
axk_padded, | |
axv_padded, | |
attn_bias_paged, | |
op=op, | |
) | |
t_ms = triton.testing.do_bench( | |
lambda g=g: g.replay(), | |
warmup=TEST_WARMUP_MS, | |
rep=TEST_RUN_MS, | |
) | |
logger.info(f"Paged attention with wasteful K/V-cache took {t_ms * 1e3:.2f}us") | |
torch.testing.assert_close( | |
y_wasteful, | |
y_usual, | |
atol=1.0e-2, | |
rtol=1.0e-2, | |
) | |
# Now let's create a "packed" K/V cache, where only meaniningful logical blocks are mapped to physical blocks | |
(block_tables, packed_cache_k, packed_cache_v) = pack_kv_cache( | |
cache_k, | |
cache_v, | |
kv_seqlens, | |
page_size, | |
) | |
attn_bias_paged = attn_bias.make_paged( | |
block_tables=block_tables, page_size=page_size, **make_paged_kwargs # type: ignore | |
) | |
axk = packed_cache_k.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV) | |
axv = packed_cache_v.view(1, -1, N_KVH_L, D_H_KV).expand(1, -1, N_H_L, D_H_KV) | |
k_cache_size_packed = axk.numel() | |
y_packed = fmha.memory_efficient_attention_forward( | |
axq, | |
axk, | |
axv, | |
attn_bias_paged, | |
op=op, | |
) | |
logger.info( | |
f"KV-cache size reduced by {(100 * (1 - k_cache_size_packed/k_cache_size_usual)):.2f}%" | |
) | |
torch.testing.assert_close(y_wasteful, y_packed) | |
# Let's swap two blocks, and adjust two corresponding entries in the block table. The result shouldn't change | |
i, j = 0, axk.shape[1] // page_size - 1 | |
axk = axk[:, :, :1, :] | |
axv = axv[:, :, :1, :] | |
vals_i = axk[:, i * page_size : (i + 1) * page_size, :, :].clone() | |
vals_j = axk[:, j * page_size : (j + 1) * page_size, :, :].clone() | |
axk[:, i * page_size : (i + 1) * page_size, :, :] = vals_j | |
axk[:, j * page_size : (j + 1) * page_size, :, :] = vals_i | |
vals_i = axv[:, i * page_size : (i + 1) * page_size, :, :].clone() | |
vals_j = axv[:, j * page_size : (j + 1) * page_size, :, :].clone() | |
axv[:, i * page_size : (i + 1) * page_size, :, :] = vals_j | |
axv[:, j * page_size : (j + 1) * page_size, :, :] = vals_i | |
axk = axk.expand(-1, -1, N_H_L, -1) | |
axv = axv.expand(-1, -1, N_H_L, -1) | |
where_i = block_tables == i | |
where_j = block_tables == j | |
block_tables.masked_fill_(where_i, j) | |
block_tables.masked_fill_(where_j, i) | |
y_swapped = fmha.memory_efficient_attention_forward( | |
axq, | |
axk, | |
axv, | |
attn_bias_paged, | |
op=op, | |
) | |
if bench: | |
g = torch.cuda.CUDAGraph() | |
with torch.cuda.graph(g): | |
y_swapped = fmha.memory_efficient_attention_forward( | |
axq, | |
axk, | |
axv, | |
attn_bias_paged, | |
op=op, | |
) | |
t_ms = triton.testing.do_bench( | |
lambda g=g: g.replay(), | |
warmup=TEST_WARMUP_MS, | |
rep=TEST_RUN_MS, | |
) | |
logger.info(f"Paged attention with packed K/V-cache took {t_ms * 1e3:.2f}us") | |
torch.testing.assert_close(y_swapped, y_packed) | |
def test_merge_attentions_nobias( | |
write_lse: bool, | |
stack_inputs: bool, | |
op: Type[AttentionFwOpBase], | |
G: Optional[int], | |
H: int, | |
): | |
""" | |
Merging the same attention twice shouldn't change anything. | |
This also tests the shape of the lse output of each permitted op. | |
""" | |
B, M, Mq, K = 13, 5, 3, 128 | |
if op is None or torch.bfloat16 in op.SUPPORTED_DTYPES: | |
dtype = torch.bfloat16 | |
else: | |
dtype = next(iter(op.SUPPORTED_DTYPES)) | |
if G is None: | |
q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda") | |
k = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K) | |
v = (3 * torch.rand(B, M, 1, K, dtype=dtype, device="cuda")).expand(B, M, H, K) | |
else: | |
q = 3 * torch.rand(B, Mq, G, H, K, dtype=dtype, device="cuda") | |
k = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand( | |
B, M, G, H, K | |
) | |
v = (3 * torch.rand(B, M, G, 1, K, dtype=dtype, device="cuda")).expand( | |
B, M, G, H, K | |
) | |
out1, lse1 = fmha.memory_efficient_attention_partial(q, k, v, op=op) | |
assert out1.shape == q.shape | |
M_ceil = lse1.shape[-1] | |
assert M_ceil >= Mq | |
assert lse1.shape == ((B, H, M_ceil) if G is None else (B, G, H, M_ceil)) | |
lse1 = lse1[..., :Mq] | |
attn_chunks = [out1, out1] | |
lse_chunks = [lse1, lse1] | |
attn_chunks_ = torch.stack(attn_chunks) if stack_inputs else attn_chunks | |
lse_chunks_ = torch.stack(lse_chunks) if stack_inputs else lse_chunks | |
out, lse = fmha.merge_attentions(attn_chunks_, lse_chunks_, write_lse=write_lse) # type: ignore | |
assert out.shape == out1.shape | |
assert_allclose(out1, out, rtol=1e-3, atol=1e-3, msg="out") | |
if write_lse: | |
assert lse is not None | |
assert lse.shape[:-1] == lse1.shape[:-1] | |
assert_allclose( | |
lse1[..., :Mq] + math.log(2), lse[..., :Mq], rtol=1e-3, atol=1e-3, msg="lse" | |
) | |
else: | |
assert lse is None | |
def test_merge_attentions_nobias_bwd( | |
op: Union[Type[AttentionFwOpBase], fmha.AttentionOp] | |
): | |
B, M, Mq, H, K = 13, 5, 5, 4, 128 | |
dtype = torch.bfloat16 | |
nparts = 3 | |
torch.manual_seed(1) | |
q = 3 * torch.rand(B, Mq, H, K, dtype=dtype, device="cuda") | |
kv = [ | |
[3 * (torch.rand(B, M, H, K, dtype=dtype, device="cuda")) for _ in range(2)] | |
for _ in range(nparts) | |
] | |
q = q.requires_grad_(True) | |
kv = [[j.requires_grad_(True) for j in i] for i in kv] | |
out_parts = [fmha.memory_efficient_attention_partial(q, k, v, op=op) for k, v in kv] | |
attn_split, lse_split = [list(x) for x in zip(*out_parts)] | |
out_merged = fmha.merge_attentions(attn_split, lse_split, write_lse=True)[0] | |
grad_out = torch.rand_like(q) | |
out_merged.backward(grad_out) | |
grad_q_out = q.grad | |
assert q.grad is not None | |
grad_kv_out = [[j.grad for j in i] for i in kv] | |
q = q.detach().requires_grad_(True) | |
kv = [[j.detach().requires_grad_(True) for j in i] for i in kv] | |
k2, v2 = [torch.cat([i[j] for i in kv], dim=1) for j in range(2)] | |
if op is None or isinstance(op, tuple): | |
full_op = op | |
else: | |
full_op = (op, None) | |
out_full = fmha.memory_efficient_attention(q, k2, v2, op=full_op) # type: ignore | |
out_full.backward(grad_out) | |
assert_allclose( | |
out_merged, out_full.to(out_merged.dtype), rtol=1e-2, atol=2e-2, msg="out" | |
) | |
atol = fmha.AttentionBwOpBase.ERROR_ATOL[dtype] * 1.5 | |
rtol = fmha.AttentionBwOpBase.ERROR_RTOL[dtype] | |
assert_allclose(grad_q_out, q.grad, rtol=rtol, atol=atol, msg="qgrad") | |
for i in range(nparts): | |
for j in range(2): | |
assert_allclose( | |
grad_kv_out[i][j], | |
kv[i][j].grad, | |
rtol=rtol, | |
atol=atol, | |
msg=f"kvgrad {i} {j}", | |
) | |
def test_partial_paged( | |
dtype: torch.dtype, op: Type[AttentionFwOpBase], num_queries: int, bmghk: bool | |
): | |
B = 128 | |
N_H_L = 8 | |
D_H = 128 | |
page_size = 256 | |
G = 2 if bmghk else 1 | |
block_tables = torch.zeros((B, 1), dtype=torch.int32, device="cuda") | |
torch.manual_seed(1) | |
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None | |
B_T = num_queries * B | |
q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda") | |
k = torch.randn((1, page_size, G, 1, D_H), dtype=dtype, device="cuda") | |
v = torch.randn_like(k) | |
k = k.expand(1, page_size, G, N_H_L, D_H) | |
v = v.expand(1, page_size, G, N_H_L, D_H) | |
if not bmghk: | |
q = q[:, :, 0] | |
k = k[:, :, 0] | |
v = v[:, :, 0] | |
attn_bias = ( | |
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( | |
q_seqlen=[num_queries] * B, | |
kv_seqlen=[1] + ([100] * (B - 1)), | |
page_size=page_size, | |
block_tables=block_tables, | |
) | |
) | |
if type(attn_bias) not in op.SUPPORTED_ATTN_BIAS_TYPES: | |
pytest.skip("Unsupported attn_bias type") | |
attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial( | |
q, | |
k, | |
v, | |
attn_bias, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
if bmghk: | |
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H) | |
assert lse_chunk.shape == ( | |
1, | |
G, | |
N_H_L, | |
B_T, | |
), f"{lse_chunk.shape=}, {(1, G, N_H_L, B_T)=}" | |
else: | |
assert attn_chunk.shape == (1, B_T, N_H_L, D_H) | |
assert lse_chunk.shape == ( | |
1, | |
N_H_L, | |
B_T, | |
), f"{lse_chunk.shape=}, {(1, N_H_L, B_T)=}" | |
def test_merge_attentions_decoding( | |
dtype: torch.dtype, | |
op: Type[AttentionFwOpBase], | |
num_queries: int, | |
bmghk: bool, | |
stack_inputs: bool, | |
): | |
""" | |
Compute decoding attention on chunks of K/V and merge them together. | |
Compare with computing attention on the whole K/V. | |
""" | |
MAX_T = 8192 | |
B = 128 | |
N_H_L = 8 | |
D_H = 128 | |
G = 2 if bmghk else 1 | |
torch.manual_seed(1) | |
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None | |
num_chunks = 10 | |
chunk_starts = sorted( | |
torch.randint(low=1, high=MAX_T // 2, size=(num_chunks,)).tolist() | |
) | |
chunk_starts[0] = 0 | |
chunk_starts.append(MAX_T) | |
# We construct sequences so that even the last chunk contains a non-empty part of every | |
# sequence, at least as long as the number of queries. | |
# Otherwise the corresponding LSE will be -inf and that'll propagate to the whole sum. | |
# It is possible to teach the kernel to ignore infinite LSEs, but in practical use cases | |
# of merging attention, e.g. a batch of sequences with a common prefix, this condition should be satisfied. | |
k_lens = torch.randint( | |
low=chunk_starts[-2] + num_queries, high=MAX_T, size=(B,) | |
).tolist() | |
q_lens = [num_queries] * B | |
B_T = num_queries * B | |
q = torch.randn((1, B_T, G, N_H_L, D_H), dtype=dtype, device="cuda") | |
k = torch.randn((B, MAX_T, G, 1, D_H), dtype=dtype, device="cuda") | |
v = torch.randn_like(k) | |
if not bmghk: | |
q = q[:, :, 0] | |
# Compute per-chunk attention | |
chunks_output = [] | |
for i in range(num_chunks): | |
chunk_start, chunk_end = chunk_starts[i], chunk_starts[i + 1] | |
k_chunk = k[:, chunk_start:chunk_end, ...] | |
v_chunk = v[:, chunk_start:chunk_end, ...] | |
axk = k_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H) | |
axv = v_chunk.reshape(-1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H) | |
if not bmghk: | |
axk = axk[:, :, 0] | |
axv = axv[:, :, 0] | |
bias_type = fmha.attn_bias.BlockDiagonalPaddedKeysMask | |
if i + 1 == num_chunks: | |
bias_type = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask | |
attn_bias = bias_type.from_seqlens( | |
q_seqlen=q_lens, | |
kv_padding=chunk_end - chunk_start, | |
kv_seqlen=[max(min(x, chunk_end) - chunk_start, 0) for x in k_lens], | |
) | |
attn_chunk, lse_chunk = fmha.memory_efficient_attention_partial( | |
q, | |
axk, | |
axv, | |
attn_bias, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
if bmghk: | |
assert attn_chunk.shape == (1, B_T, G, N_H_L, D_H) | |
assert lse_chunk.shape == (1, G, N_H_L, B_T) | |
else: | |
assert attn_chunk.shape == (1, B_T, N_H_L, D_H) | |
assert lse_chunk.shape == (1, N_H_L, B_T) | |
chunks_output.append((attn_chunk, lse_chunk)) | |
# Merge attention from all chunks | |
attn_split = [attn_chunk for attn_chunk, _ in chunks_output] | |
lse_split = [lse_chunk for _, lse_chunk in chunks_output] | |
attn_split_ = torch.stack(attn_split) if stack_inputs else attn_split | |
lse_split_ = torch.stack(lse_split) if stack_inputs else lse_split | |
attn_out, lse_out = fmha.merge_attentions( | |
attn_split_, lse_split_, output_dtype=dtype # type: ignore | |
) | |
assert lse_out is not None | |
# Compute attention on the full K/V | |
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( | |
q_seqlen=q_lens, | |
kv_padding=MAX_T, | |
kv_seqlen=k_lens, | |
) | |
axk = k.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H) | |
axv = v.view(1, -1, G, 1, D_H).expand(1, -1, G, N_H_L, D_H) | |
if not bmghk: | |
axk = axk[:, :, 0] | |
axv = axv[:, :, 0] | |
attn_full, lse_full = fmha.memory_efficient_attention_partial( | |
q, | |
axk, | |
axv, | |
attn_bias, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
assert_allclose( | |
lse_out.to(lse_full.dtype), lse_full, rtol=1e-3, atol=1e-3, msg="lse" | |
) | |
assert_allclose( | |
attn_out.to(attn_full.dtype), attn_full, rtol=1e-3, atol=1e-3, msg="out" | |
) | |
attn_full2 = fmha.memory_efficient_attention_forward( | |
q, | |
axk, | |
axv, | |
attn_bias, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
assert_allclose(attn_full2, attn_full, rtol=1e-3, atol=1e-3, msg="out2") | |
def test_merge_attentions_sharedinput( | |
dtype: torch.dtype, | |
op: Type[AttentionFwOpBase], | |
gqa: bool, | |
): | |
""" | |
Compute decoding attention on two chunks of K/V, where the first chunk contains a prefix | |
shared between sequences, and merge them together. Compare with attention on the whole K/V. | |
""" | |
MAX_T = 8192 | |
N_H_L = 16 | |
D_H = 128 | |
G = 2 | |
torch.manual_seed(1) | |
output_dtype = torch.float32 if op.SUPPORTS_OUTPUT_DTYPE else None | |
shared_length = 20 | |
full_lengths = [30, 35, 40] | |
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( | |
q_seqlen=[1, 1, 1], | |
kv_padding=MAX_T, | |
kv_seqlen=full_lengths, | |
) | |
attn_bias1 = fmha.attn_bias.BlockDiagonalPaddedKeysMask.from_seqlens( | |
q_seqlen=[2, 1], | |
kv_padding=MAX_T, | |
kv_seqlen=[shared_length, 0], | |
) | |
attn_bias2 = fmha.attn_bias.BlockDiagonalGappyKeysMask.from_seqlens( | |
q_seqlen=[1, 1, 1], | |
kv_seqstarts=[shared_length, MAX_T + shared_length, 2 * MAX_T, 3 * MAX_T], | |
kv_seqlen=[ | |
full_lengths[0] - shared_length, | |
full_lengths[1] - shared_length, | |
full_lengths[2], | |
], | |
) | |
q = torch.randn((1, 3, G, N_H_L, D_H), dtype=dtype, device="cuda") | |
k = torch.randn((3, MAX_T, G, 1 if gqa else N_H_L, D_H), dtype=dtype, device="cuda") | |
v = torch.randn_like(k) | |
k[1, :shared_length] = k[0, :shared_length] | |
v[1, :shared_length] = v[0, :shared_length] | |
k = k.flatten(end_dim=1)[None] | |
v = v.flatten(end_dim=1)[None] | |
k = k.expand((1, 3 * MAX_T, G, N_H_L, D_H)) | |
v = v.expand((1, 3 * MAX_T, G, N_H_L, D_H)) | |
attn_chunk1, lse_chunk1 = fmha.memory_efficient_attention_partial( | |
q, | |
k, | |
v, | |
attn_bias1, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
assert attn_chunk1.shape == (1, 3, G, N_H_L, D_H) | |
assert lse_chunk1.shape == (1, G, N_H_L, 3) | |
if gqa: | |
attn_chunk1a, lse_chunk1a = fmha.memory_efficient_attention_partial( | |
q, | |
k.contiguous(), | |
v, | |
attn_bias1, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
assert attn_chunk1a.shape == (1, 3, G, N_H_L, D_H) | |
assert lse_chunk1a.shape == (1, G, N_H_L, 3) | |
assert_allclose( | |
attn_chunk1a.nan_to_num(0, 0, 0), attn_chunk1.nan_to_num(0, 0, 0) | |
) | |
assert_allclose(lse_chunk1a.nan_to_num(0, 0, 0), lse_chunk1.nan_to_num(0, 0, 0)) | |
attn_chunk2, lse_chunk2 = fmha.memory_efficient_attention_partial( | |
q, | |
k, | |
v, | |
attn_bias2, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
assert attn_chunk2.shape == (1, 3, G, N_H_L, D_H) | |
assert lse_chunk2.shape == (1, G, N_H_L, 3) | |
# Merge attention from all chunks | |
attn_out, lse_out = fmha.merge_attentions( | |
[attn_chunk1, attn_chunk2], [lse_chunk1, lse_chunk2], output_dtype=dtype # type: ignore | |
) | |
assert lse_out is not None | |
# Compute attention on the full K/V | |
attn_full, lse_full = fmha.memory_efficient_attention_partial( | |
q, | |
k, | |
v, | |
attn_bias, | |
op=op, | |
output_dtype=output_dtype, | |
) | |
assert_allclose( | |
attn_out.to(attn_full.dtype), attn_full, rtol=1e-2, atol=2e-3, msg="out" | |
) | |
assert_allclose( | |
lse_out.to(lse_full.dtype), lse_full, rtol=1e-3, atol=1e-3, msg="lse" | |
) | |
# Gradient with respect to attention, LSE, or neither | |
def test_merge_attentions_against_ref( | |
bmghk: bool, stack_inputs: bool, grad_var: Optional[str] | |
): | |
split_k = 16 | |
B = 12 | |
M = 137 | |
G = 2 if bmghk else 1 | |
N_H_L = 8 | |
D_H = 128 | |
dtype = torch.float32 | |
attn_split = torch.randn([split_k, B, M, G, N_H_L, D_H], dtype=dtype, device="cuda") | |
lse_split = torch.randn([split_k, B, G, N_H_L, M], dtype=dtype, device="cuda") | |
if not bmghk: | |
attn_split = attn_split[:, :, :, 0] | |
lse_split = lse_split[:, :, 0] | |
if grad_var is not None: | |
attn_split.requires_grad_(True) | |
lse_split.requires_grad_(True) | |
attn_out_ref, lse_out_ref = _merge_attentions_ref(attn_split, lse_split) | |
if grad_var is not None: | |
if grad_var == "attn": | |
out_grad = torch.randn_like(attn_out_ref) | |
attn_out_ref.backward(out_grad) | |
else: | |
out_grad = torch.randn_like(lse_out_ref) | |
lse_out_ref.backward(out_grad) | |
attn_grad_ref, lse_grad_ref = attn_split.grad, lse_split.grad | |
attn_split = attn_split.detach().unbind(0) # type: ignore | |
lse_split = lse_split.detach().unbind(0) # type: ignore | |
for x in attn_split + lse_split: | |
x.requires_grad_(True) | |
x.retain_grad() | |
attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split) | |
torch.testing.assert_close(lse_out, lse_out_ref, rtol=1e-4, atol=1e-4) | |
torch.testing.assert_close(attn_out, attn_out_ref, rtol=1e-4, atol=1e-4) | |
if grad_var is not None: | |
if grad_var == "attn": | |
attn_out.backward(out_grad) | |
else: | |
assert lse_out is not None | |
lse_out.backward(out_grad) | |
attn_grads = [x.grad for x in attn_split] | |
lse_grads = [x.grad for x in lse_split] | |
attn_grad_concat = torch.stack(attn_grads, dim=0) | |
lse_grad_concat = torch.stack(lse_grads, dim=0) | |
if grad_var == "lse": | |
# LSE doesn't depend on attn_split, so when only gradient with respect to LSE is provided as input, | |
# the output gradient with respect to attn_split is zero. | |
# The reference implementation produced None instead of zero in this case | |
attn_grad_ref = torch.zeros_like(attn_grad_concat) | |
torch.testing.assert_close(lse_grad_concat, lse_grad_ref, rtol=1e-4, atol=1e-4) | |
torch.testing.assert_close( | |
attn_grad_concat, attn_grad_ref, rtol=1e-4, atol=1e-4 | |
) | |
def _merge_attentions_ref(attn_split, lse_split): | |
""" | |
attn_split: [split_k, B, M, (G,) H, Kq] | |
lse_split: [split_k, B, (G,) H, M] | |
""" | |
is_bmghk = len(attn_split.shape) == 6 | |
if not is_bmghk: | |
attn_split = attn_split.unsqueeze(3) | |
lse_split = lse_split.unsqueeze(2) | |
lse_split = lse_split[..., None].moveaxis(4, 2) # [split_k, B, M, G, H, 1] | |
lse_max, _ = torch.max(lse_split, dim=0) # [B, M, G, H, 1] | |
sumexp_normalized = torch.exp(lse_split - lse_max) # [split_k, B, M, G, H, 1] | |
denominator = sumexp_normalized.sum(dim=0) # [B, M, G, H, 1] | |
numerator = (sumexp_normalized * attn_split).sum(dim=0) # [B, M, G, H, Kq] | |
attn_out = numerator / denominator # [B, M, G, H, Kq] | |
lse_out = lse_max + torch.log(denominator) | |
lse_out = lse_out.squeeze(4).permute(0, 2, 3, 1) # [B, G, H, M] | |
if not is_bmghk: | |
attn_out = attn_out.squeeze(2) | |
lse_out = lse_out.squeeze(1) | |
return attn_out, lse_out | |
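# The merge rule implemented above, for chunks i with partial outputs o_i and
# log-sum-exp values l_i (per query and head):
#     L = log(sum_i exp(l_i))          -- merged LSE
#     O = sum_i exp(l_i - L) * o_i     -- merged attention output
# which is exactly how a softmax over the concatenation of the chunks' keys
# decomposes chunk by chunk.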
# rocm doesn't support backward yet | |
def test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -> None: | |
torch.manual_seed(0) | |
torch._dynamo.reset_code_caches() # avoids hitting recompilation limit | |
B, M, H, K = 1, 256, 2, 64 | |
q, k, v, bias = create_tensors( | |
op if op is None else op[0], | |
"cuda", | |
torch.float16, | |
bias_t, | |
B, | |
M, | |
M, | |
H, | |
K, | |
K, | |
fmt="BMHK", | |
) | |
grad = torch.randn_like(q) | |
if create_bias_inside_compiled: | |
bias = None | |
if bias_t not in [None, fmha.attn_bias.LowerTriangularMask]: | |
pytest.skip("Can't create this mask inside compile") | |
if bias is not None: | |
bias = bias.to(q.device) | |
q.requires_grad_(True) | |
k.requires_grad_(True) | |
v.requires_grad_(True) | |
def fmha_fn(q, k, v, bias): | |
if create_bias_inside_compiled and bias_t is not None: | |
bias = bias_t() | |
return fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=op) | |
# Eager reference | |
out_ref = fmha_fn(q, k, v, bias) | |
out_ref.backward(grad) | |
dq_ref, dk_ref, dv_ref = q.grad, k.grad, v.grad | |
q.grad, k.grad, v.grad = None, None, None | |
# Compiled version | |
fmha_c = torch.compile(fmha_fn, fullgraph=True, dynamic=False) | |
out = fmha_c(q, k, v, bias) | |
out.backward(grad) | |
assert_allclose( | |
out, | |
out_ref, | |
"out", | |
atol=fmha.flash.FwOp.ERROR_ATOL[q.dtype], | |
rtol=fmha.flash.FwOp.ERROR_RTOL[q.dtype], | |
) | |
atol, rtol = ( | |
fmha.flash.BwOp.ERROR_ATOL[q.dtype], | |
fmha.flash.BwOp.ERROR_RTOL[q.dtype], | |
) | |
assert_allclose(q.grad, dq_ref, "dq", atol=atol, rtol=rtol) | |
assert_allclose(k.grad, dk_ref, "dk", atol=atol, rtol=rtol) | |
assert_allclose(v.grad, dv_ref, "dv", atol=atol, rtol=rtol) | |
def test_bias_lower_triangular() -> None: | |
mask = fmha.attn_bias.LowerTriangularMask() | |
mask.detach() | |
def test_bias_lower_triangular_with_bias() -> None: | |
dense_bias = torch.randn([128, 128], dtype=torch.float16, requires_grad=True) | |
grad = torch.randn_like(dense_bias) | |
mask = fmha.attn_bias.LowerTriangularMask() | |
mask_biased = mask.add_bias(dense_bias) | |
mask_biased2 = mask_biased.detach() | |
mask_biased.backward(grad) | |
assert dense_bias.grad is not None | |
assert mask_biased2.grad is None | |
assert_allclose(dense_bias.grad, grad, "dense.grad") | |