# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from functools import wraps
from typing import List, Optional, Tuple
import numpy as np
import pytest
import torch
from xformers.attn_bias_utils import ref_attention, ref_attention_bmhk

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
rocm_only = pytest.mark.skipif(
not torch.cuda.is_available() or not torch.version.hip, reason="requires ROCM"
)
disable_on_rocm = pytest.mark.skipif(
    bool(torch.version.hip), reason="not supported on ROCM"
)
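

# Illustrative only: how these markers typically gate a test. The test names
# below are hypothetical and not part of this module.
#
#     @cuda_only
#     def test_kernel_cuda() -> None: ...
#
#     @disable_on_rocm
#     def test_kernel_skipped_on_rocm() -> None: ...
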
def disable_tf32(fn):
    """Run ``fn`` with TF32 matmuls disabled, restoring the previous flags afterwards."""

    @wraps(fn)
def wrapped(*args, **kwargs):
cuda, cudnn = (
torch.backends.cuda.matmul.allow_tf32,
torch.backends.cudnn.allow_tf32,
)
torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32 = (
False,
False,
)
try:
return fn(*args, **kwargs)
finally:
torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32 = (
cuda,
cudnn,
)
return wrapped


ref_attention_for_test = disable_tf32(ref_attention)
ref_attention_bmhk_for_test = disable_tf32(ref_attention_bmhk)
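

# Usage sketch (hypothetical test body, not part of this module): compare an
# op under test against the TF32-disabled reference, so that TF32 rounding in
# the reference matmuls does not mask real errors in the op.
#
#     out = my_attention_op(q, k, v)  # hypothetical op under test
#     ref = ref_attention_for_test(q, k, v)
#     assert_allclose(out, ref.to(out.dtype), "fwd", atol=2e-4)
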
def assert_allclose(
out: Optional[torch.Tensor],
ref: Optional[torch.Tensor],
msg: str = "failed",
atol: float = 1e-8,
rtol: float = 1e-5,
) -> None:
assert out is not None, f"{msg}: output Tensor is None"
assert ref is not None, f"{msg}: reference Tensor is None"
assert out.shape == ref.shape, f"Shape: {out.shape} (expected: {ref.shape})"
    assert out.dtype == ref.dtype, f"out dtype: {out.dtype}, ref dtype: {ref.dtype}"
if out.numel() == 0:
return
    # Per-element violation of the allclose criterion: positive where failing.
    flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
max_pos = flatten_diff.argmax()
max_location = np.unravel_index(int(max_pos), out.shape)
max_diff = flatten_diff[max_pos]
    num_different = torch.count_nonzero(flatten_diff > 0)
percentage = num_different / flatten_diff.numel()
del flatten_diff
assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
f"{msg}: "
f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
f" at {max_location} of shape {tuple(out.shape)} / atol={atol}, rtol={rtol}"
f"/ total failing elements: {num_different} ({percentage*100:.3}%)"
)
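

# Illustrative only (values are made up): on mismatch, assert_allclose reports
# the worst-offending element, its location, and the fraction of failing entries.
#
#     out = torch.tensor([1.0, 2.0, 3.0])
#     ref = torch.tensor([1.0, 2.0, 3.1])
#     assert_allclose(out, ref, "demo", atol=1e-2, rtol=1e-2)  # raises AssertionError
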
def pack_kv_cache(
cache_k: torch.Tensor,
cache_v: torch.Tensor,
kv_seqlens: List[int],
BLOCK_N: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Create block tables and pages K/V cache for testing paged attention.
Args:
cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D].
Note that these tensors are unexpanded,
i.e. for multiquery case cache_k.shape[2] = 1
kv_seqlens: list of K/V sequence lengths
BLOCK_N: number of tokens per per paged attention block
B: batch size
Returns:
block_tables: [B, MAX_BLOCKS]
packed_cache_k: [1, total_len_rounded, H_kv, D]
packed_cache_v: [1, total_len_rounded, H_kv, D]
where total_len_rounded is a sum of K/V seqlens, each rounded up
to a multiple of BLOCK_N.
"""
kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens]
total_len_rounded = sum(kv_seqlens_rounded)
B, MAX_T, H, D = cache_k.shape
packed_cache_k = torch.empty(
total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
)
packed_cache_v = torch.empty(
total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
)
    # Copy each batch element's valid K/V tokens into its rounded-up slot.
    seqstart = 0
for b in range(B):
packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[
b, : kv_seqlens[b]
].clone()
packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[
b, : kv_seqlens[b]
].clone()
seqstart += kv_seqlens_rounded[b]
num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N
block_tables = (
        torch.arange(num_blocks_per_row, device=cache_k.device, dtype=torch.int32)
.unsqueeze(0)
.expand(B, num_blocks_per_row)
)
    # Starting block index of each sequence within the packed layout.
    seqstarts = (
        (
            torch.tensor(kv_seqlens_rounded).cumsum(dim=0)
            - torch.tensor(kv_seqlens_rounded)
        )
        .to(device=cache_k.device)
        .unsqueeze(1)
    ) // BLOCK_N
block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32)
return (
block_tables,
packed_cache_k.unsqueeze(0),
packed_cache_v.unsqueeze(0),
)
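

if __name__ == "__main__":
    # Minimal smoke test for pack_kv_cache: a sketch under assumed small shapes,
    # guarded to run only when CUDA is available, matching how the cache is
    # used in the attention tests.
    if torch.cuda.is_available():
        B, MAX_T, H_kv, D, BLOCK_N = 2, 16, 1, 8, 4
        cache_k = torch.randn(B, MAX_T, H_kv, D, device="cuda")
        cache_v = torch.randn(B, MAX_T, H_kv, D, device="cuda")
        kv_seqlens = [5, 9]  # rounded up to 8 and 12 -> total_len_rounded == 20
        block_tables, packed_k, packed_v = pack_kv_cache(
            cache_k, cache_v, kv_seqlens, BLOCK_N
        )
        assert block_tables.shape == (B, MAX_T // BLOCK_N)
        assert packed_k.shape == (1, 20, H_kv, D)
        # The first sequence occupies blocks 0..3; the second starts at
        # block 8 // BLOCK_N == 2 of the packed layout.
        assert block_tables[1, 0].item() == 2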