Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Optional, Tuple | |
import pytest | |
import torch | |
import xformers.ops | |
from xformers.ops import fmha | |
from .utils import assert_allclose, disable_tf32, ref_attention_for_test | |
def ref_attention_splitk_bmhk(
    q, k, v, attn_bias, scale=None, split_k=None, dtype=None
) -> torch.Tensor:
    """Split-k reference attention for 4-D BMHK inputs.

    Folds the head axis into the batch axis (``(B, M, H, K) -> (B*H, M, K)``),
    runs the 3-D ``ref_attention_splitk`` and unfolds the result back to
    ``(B, Mq, H, Kv)``.

    Args:
        q, k, v: 4-D BMHK tensors.
        attn_bias: ``None``, a dense bias tensor (forwarded unchanged), or an
            ``xformers.ops.AttentionBias``, which is materialized here as a
            ``(B, H, Mq, Mk)`` float32 tensor and flattened to
            ``(B*H, Mq, Mk)``.
        scale: softmax scale, forwarded to ``ref_attention_splitk``.
        split_k: number of key chunks. ``None`` means "use
            ``ref_attention_splitk``'s own default"; forwarding ``None``
            itself would fail in its ``k.size(-2) // split_k``.
        dtype: compute dtype, forwarded to ``ref_attention_splitk``.

    Returns:
        Tensor of shape ``(B, Mq, H, Kv)``.
    """
    assert q.ndim == 4

    def T(t):
        # (B, M, H, K) -> (B, H, M, K) -> (B*H, M, K)
        return t.permute((0, 2, 1, 3)).reshape(
            [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
        )

    if isinstance(attn_bias, xformers.ops.AttentionBias):
        attn_bias = attn_bias.materialize(
            (q.shape[0], q.shape[2], q.shape[1], k.shape[1]),
            device=q.device,
            dtype=torch.float32,
        ).reshape([q.shape[0] * q.shape[2], q.shape[1], k.shape[1]])
    # Only pass split_k when given so that None falls back to the callee's default.
    split_kwargs = {} if split_k is None else {"split_k": split_k}
    out = ref_attention_splitk(
        T(q), T(k), T(v), attn_bias, scale=scale, dtype=dtype, **split_kwargs
    )
    # (B*H, Mq, Kv) -> (B, H, Mq, Kv) -> (B, Mq, H, Kv)
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3))
def ref_attention_splitk(
    q, k, v, attn_bias, scale=None, split_k=2, dtype=None
) -> torch.Tensor:
    """Reference attention computed with a split-k (chunked keys) reduction.

    Keys/values are split into chunks along the key-sequence dimension. Each
    chunk yields an unnormalized partial output plus its row-wise max logit
    and sum of exponentials. The partials are then merged with log-sum-exp
    rescaling and normalized once at the end.

    Args:
        q, k, v: 3-D ``(batch, seq, head_dim)`` tensors (``batch`` may already
            fold in heads), 4-D BMHK tensors, or 5-D BMGHK tensors. 4-D
            inputs are routed through ``ref_attention_splitk_bmhk``. 5-D
            inputs make one such call per group along dim 2 and stack the
            results on dim 2.
        attn_bias: ``None``, a dense bias tensor, or an
            ``xformers.ops.AttentionBias``. On the 3-D path an AttentionBias
            is materialized as ``(B, 1, Mq, Mk)``. Any 4-D bias is flattened
            to ``(B*H, Mq, Mk)`` and must satisfy ``q.shape[0] == B*H``.
        scale: softmax scale; defaults to ``q.shape[-1] ** -0.5``.
        split_k: requested number of key chunks. The chunk size is
            ``max(1, Mk // split_k)``, so the actual chunk count can differ
            when ``Mk`` is not divisible by ``split_k``.
        dtype: compute dtype; defaults to ``torch.float32``.

    Returns:
        Attention output with ``q``'s leading dimensions and last dimension
        ``v.shape[-1]``.
    """
    if q.ndim == 5:

        def attn_bias_group(group: int):
            # Dense biases are sliced per group along dim 1; None and
            # structured biases are shared by every group.
            if attn_bias is not None and getattr(
                attn_bias, "HOLDS_DENSE_TENSOR", True
            ):
                return attn_bias[:, group]
            return attn_bias

        return torch.stack(
            [
                ref_attention_splitk_bmhk(
                    q[:, :, g],
                    k[:, :, g],
                    v[:, :, g],
                    attn_bias=attn_bias_group(g),
                    scale=scale,
                    split_k=split_k,
                    dtype=dtype,
                )
                for g in range(q.shape[2])
            ],
            dim=2,
        )
    if q.ndim == 4:
        return ref_attention_splitk_bmhk(
            q, k, v, attn_bias=attn_bias, scale=scale, split_k=split_k, dtype=dtype
        )
    assert q.ndim == 3
    if dtype is None:
        dtype = torch.float32
    q = q.to(dtype=dtype)
    k = k.to(dtype=dtype)
    v = v.to(dtype=dtype)

    if scale is None:
        scale = q.shape[-1] ** -0.5
    assert not q.isnan().any()
    q = q * scale
    assert not q.isnan().any()

    attn_bias_tensor = None
    if attn_bias is not None:
        if isinstance(attn_bias, xformers.ops.AttentionBias):
            # Always create in B,H,Mq,Mk format
            attn_bias_tensor = attn_bias.materialize(
                (q.shape[0], 1, q.shape[1], k.shape[1]),
                device=q.device,
                dtype=torch.float32,
            )
        else:
            attn_bias_tensor = attn_bias
        if attn_bias_tensor.ndim == 4:
            assert q.shape[0] == attn_bias_tensor.shape[0] * attn_bias_tensor.shape[1]
            attn_bias_tensor = attn_bias_tensor.reshape(
                [-1, *attn_bias_tensor.shape[2:]]
            )

    # Clamp to 1: split_k > Mk would otherwise give a zero chunk size.
    split_size = max(1, k.size(-2) // split_k)
    k_split = torch.split(k, split_size, dim=-2)
    v_split = torch.split(v, split_size, dim=-2)
    if attn_bias_tensor is None:
        attn_bias_split = [None] * len(k_split)
    else:
        attn_bias_split = torch.split(attn_bias_tensor, split_size, dim=-1)

    def compute_attention_split(q_whole, k_slice, v_slice, attn_bias_slice):
        """Unnormalized attention over one key chunk, plus its row statistics."""
        p_slice = q_whole @ k_slice.transpose(-2, -1)
        if attn_bias_slice is not None:
            # In-place, so p_slice keeps the compute dtype.
            p_slice += attn_bias_slice
        row_max = torch.max(p_slice, dim=-1, keepdim=True).values
        p_slice_scaled = p_slice - row_max
        # Fully-masked rows give (-inf) - (-inf) = nan; map back to -inf so
        # they contribute exp(-inf) = 0.
        p_slice_scaled[p_slice_scaled.isnan()] = float("-inf")
        s = torch.exp(p_slice_scaled)
        row_sumexp = torch.sum(s, dim=-1, keepdim=True)
        attn_slice = s @ v_slice
        return {
            "attn_slice": attn_slice,
            "row_max": row_max,
            "row_sumexp": row_sumexp,
        }

    slices = [
        compute_attention_split(q, k_slice, v_slice, bias_slice)
        for k_slice, v_slice, bias_slice in zip(k_split, v_split, attn_bias_split)
    ]

    # Reduce over split-k slices. The accumulators start at the merge identity
    # (out = 0, max = -inf, sumexp = 0). out is shaped like a partial output,
    # so v's head dim may differ from q's.
    out = torch.zeros_like(slices[0]["attn_slice"])
    global_max = torch.zeros_like(slices[0]["row_max"]).fill_(float("-inf"))
    global_sumexp = torch.zeros_like(slices[0]["row_sumexp"])
    for s in slices:
        local_out = s["attn_slice"]
        local_max = s["row_max"]
        local_sumexp = s["row_sumexp"]
        # exp(-|difference|) rescales whichever side has the smaller max.
        log_alpha = -torch.abs(local_max - global_max)
        alpha = torch.exp(log_alpha)
        # (-inf) - (-inf) = nan when both maxes are -inf: keep both sides as-is.
        alpha.nan_to_num_(1.0)
        pick_new = local_max < global_max
        new_coef = torch.where(pick_new, alpha, 1.0)
        curr_coef = torch.where(pick_new, 1.0, alpha)
        out = out * curr_coef + local_out * new_coef
        global_sumexp = global_sumexp * curr_coef + local_sumexp * new_coef
        global_max = torch.max(local_max, global_max)
    out /= global_sumexp
    return out
def _kv_heads_label(kv_heads: Optional[int]) -> str: | |
if kv_heads is None: | |
return "" | |
if kv_heads == 1: | |
return "mq" | |
return f"gqa{kv_heads}" | |
def test_splitk_reference(
    kv_heads: int,
    n_heads: int,
    padding: int,
    bsz: int,
    dtype: str,
    device: str,
    split_k: int,
):
    """Check ``ref_attention_splitk`` against ``ref_attention_for_test``.

    Builds random inputs with one query per batch element and keys padded to
    ``padding`` slots per element. Each element gets a random true key length
    in ``[1, padding]``. Both references are evaluated under a
    ``BlockDiagonalCausalWithOffsetPaddedKeysMask`` bias and compared with
    the CK forward operator's tolerances for ``dtype``.

    Layout selected by ``kv_heads``:
      * ``None``: BMHK with independent K/V per head.
      * ``1``: BMHK where every head shares the first K/V head.
      * ``> 1``: BMGHK with ``kv_heads`` groups; within each group every head
        shares that group's first K/V head.
    """
    # NOTE(review): no pytest parametrize decorators are visible for these
    # arguments; presumably they were lost upstream of this view — confirm.
    torch_dtype = {
        "f16": torch.float16,
        "bf16": torch.bfloat16,
        "f32": torch.float32,
    }[dtype]
    torch.manual_seed(1)
    head_dim = 256
    queries_per_seq = 1
    q_rows = bsz * queries_per_seq
    kv_rows = bsz * padding

    k_shape: Tuple[int, ...]
    q_shape: Tuple[int, ...]
    if kv_heads is None or kv_heads <= 1:
        k_shape = (1, kv_rows, n_heads, head_dim)
        q_shape = (1, q_rows, n_heads, head_dim)
    else:
        k_shape = (1, kv_rows, kv_heads, n_heads, head_dim)
        q_shape = (1, q_rows, kv_heads, n_heads, head_dim)

    # Draw order (k, key lengths, v, q) is fixed so the seed reproduces.
    k = torch.rand(k_shape, dtype=torch_dtype, device=device)
    kv_lengths = torch.randint(1, padding + 1, (bsz,)).tolist()
    v = torch.rand_like(k)
    q = torch.rand(q_shape, dtype=torch_dtype, device=device)
    causal_diagonal = torch.tensor(  # TODO: make unnecessary
        [length - 1 for length in kv_lengths], dtype=torch.int32, device=device
    )
    if kv_heads is not None:
        # Broadcast the first head along the head axis so heads share K/V.
        k = k[..., :1, :].expand(k_shape)
        v = v[..., :1, :].expand(k_shape)

    attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1] * bsz,
        kv_seqlen=kv_lengths,
        causal_diagonal=causal_diagonal,
        kv_padding=padding,
    )
    expected = ref_attention_for_test(q, k, v, attn_bias)
    actual = ref_attention_splitk(q, k, v, attn_bias, None, split_k=split_k)
    assert_allclose(
        expected,
        actual,
        atol=fmha.ck.FwOp.ERROR_ATOL[torch_dtype],
        rtol=fmha.ck.FwOp.ERROR_RTOL[torch_dtype],
    )