Flexstorydiff / xformers /tests /test_splitk_reference.py
FlexTheAi's picture
Upload folder using huggingface_hub
e202b16 verified
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import pytest
import torch
import xformers.ops
from xformers.ops import fmha
from .utils import assert_allclose, disable_tf32, ref_attention_for_test
@disable_tf32
def ref_attention_splitk_bmhk(
q, k, v, attn_bias, scale=None, split_k=None, dtype=None
) -> torch.Tensor:
assert q.ndim == 4
def T(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)
if isinstance(attn_bias, xformers.ops.AttentionBias):
attn_bias = attn_bias.materialize(
(q.shape[0], q.shape[2], q.shape[1], k.shape[1]),
device=q.device,
dtype=torch.float32,
).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]])
out = ref_attention_splitk(
T(q), T(k), T(v), attn_bias, scale=scale, split_k=split_k, dtype=dtype
)
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
return out.permute((0, 2, 1, 3))
@disable_tf32
def ref_attention_splitk(
q, k, v, attn_bias, scale=None, split_k=2, dtype=None
) -> torch.Tensor:
if q.ndim == 5:
def attn_bias_group(group: int):
if getattr(attn_bias, "HOLDS_DENSE_TENSOR", True):
return attn_bias[:, group]
return attn_bias
return torch.stack(
[
ref_attention_splitk_bmhk(
q[:, :, g],
k[:, :, g],
v[:, :, g],
attn_bias=attn_bias_group(g),
split_k=split_k,
dtype=dtype,
)
for g in range(q.shape[2])
],
dim=2,
)
if q.ndim == 4:
return ref_attention_splitk_bmhk(
q, k, v, attn_bias=attn_bias, split_k=split_k, dtype=dtype
)
assert q.ndim == 3
if dtype is None:
dtype = torch.float32
q = q.to(dtype=dtype)
k = k.to(dtype=dtype)
v = v.to(dtype=dtype)
if scale is None:
scale = q.shape[-1] ** -0.5
assert not q.isnan().any()
q = q * scale
assert not q.isnan().any()
if attn_bias is not None:
if isinstance(attn_bias, xformers.ops.AttentionBias):
# Always create in B,H,Mq,Mk format
attn_bias_tensor = attn_bias.materialize(
(q.shape[0], 1, q.shape[1], k.shape[1]),
device=q.device,
dtype=torch.float32,
)
else:
attn_bias_tensor = attn_bias
if attn_bias_tensor.ndim == 4:
assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1]
attn_bias_tensor = attn_bias_tensor.reshape(
[-1, *attn_bias_tensor.shape[2:]]
)
split_size = k.size(-2) // split_k
split_config = {"dim": -2, "split_size_or_sections": split_size}
k_split = torch.split(k, **split_config)
v_split = torch.split(v, **split_config)
attn_bias_split = torch.split(
attn_bias_tensor, dim=-1, split_size_or_sections=split_size
)
def compute_attention_split(q_whole, k_slice, v_slice, attn_bias_slice):
p_slice = q_whole @ k_slice.transpose(-2, -1)
p_slice += attn_bias_slice
row_max = torch.max(p_slice, dim=-1, keepdim=True).values
p_slice_scaled = p_slice - row_max
p_slice_scaled[p_slice_scaled.isnan()] = float("-inf")
s = torch.exp(p_slice_scaled)
row_sumexp = torch.sum(s, dim=-1, keepdim=True)
attn_slice = s @ v_slice
return {
"attn_slice": attn_slice,
"row_max": row_max,
"row_sumexp": row_sumexp,
}
splits = list(zip(k_split, v_split, attn_bias_split))
slices = list(map(lambda s: compute_attention_split(q, s[0], s[1], s[2]), splits))
out = torch.zeros_like(q)
# reduce out over split-k slices
global_max = torch.zeros_like(slices[0]["row_max"]).fill_(float("-inf"))
global_sumexp = torch.zeros_like(slices[0]["row_sumexp"])
for s in slices:
local_out = s["attn_slice"]
local_max = s["row_max"]
local_sumexp = s["row_sumexp"]
log_alpha = -torch.abs(local_max - global_max)
alpha = torch.exp(log_alpha)
alpha.nan_to_num_(1.0)
pick_new = local_max < global_max
new_coef = torch.where(pick_new, alpha, 1.0)
curr_coef = torch.where(pick_new, 1.0, alpha)
out = out * curr_coef + local_out * new_coef
global_sumexp = global_sumexp * curr_coef + local_sumexp * new_coef
global_max = torch.max(local_max, global_max)
out /= global_sumexp
return out
def _kv_heads_label(kv_heads: Optional[int]) -> str:
if kv_heads is None:
return ""
if kv_heads == 1:
return "mq"
return f"gqa{kv_heads}"
@pytest.mark.parametrize("dtype", ["f32"])
@pytest.mark.parametrize("kv_heads", [None, 1, 2], ids=_kv_heads_label)
@pytest.mark.parametrize("n_heads", [16])
@pytest.mark.parametrize("padding, bsz", [(32, 8), (4096, 1)])
@pytest.mark.parametrize("split_k", [1, 2, 4])
@pytest.mark.parametrize("device", ["cpu"])
def test_splitk_reference(
kv_heads: int,
n_heads: int,
padding: int,
bsz: int,
dtype: str,
device: str,
split_k: int,
):
dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype]
torch.manual_seed(1)
d = 256
num_queries = 1
if kv_heads is not None and kv_heads > 1:
k_shape: Tuple[int, ...] = (1, bsz * padding, kv_heads, n_heads, d)
q_shape: Tuple[int, ...] = (
1,
bsz * num_queries,
kv_heads,
n_heads,
d,
)
else:
k_shape = (1, bsz * padding, n_heads, d)
q_shape = (1, bsz * num_queries, n_heads, d)
k = torch.rand(k_shape, dtype=dtype_, device=device)
k_seqlen = torch.randint(1, padding + 1, (bsz,)).tolist()
v = torch.rand_like(k)
q = torch.rand(q_shape, dtype=dtype_, device=device)
causal_diagonal = torch.tensor( # TODO: make unnecessary
[i - 1 for i in k_seqlen], dtype=torch.int32, device=device
)
if kv_heads is not None:
k = k[..., :1, :].expand(k_shape)
v = v[..., :1, :].expand(k_shape)
attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[1] * bsz,
kv_seqlen=k_seqlen,
causal_diagonal=causal_diagonal,
kv_padding=padding,
)
ref_out = ref_attention_for_test(q, k, v, attn_bias)
splitk_out = ref_attention_splitk(q, k, v, attn_bias, None, split_k=split_k)
assert_allclose(
ref_out,
splitk_out,
atol=fmha.ck.FwOp.ERROR_ATOL[dtype_],
rtol=fmha.ck.FwOp.ERROR_RTOL[dtype_],
)