# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from functools import wraps
from typing import List, Optional, Tuple

import numpy as np
import pytest
import torch

from xformers.attn_bias_utils import ref_attention, ref_attention_bmhk

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
rocm_only = pytest.mark.skipif(
    not torch.cuda.is_available() or not torch.version.hip, reason="requires ROCm"
)
disable_on_rocm = pytest.mark.skipif(
    bool(torch.version.hip), reason="not supported on ROCm"
)
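
# A minimal usage sketch (hypothetical test, not part of this module): the
# markers above attach to individual tests in the usual pytest way, skipping
# the test when the required backend is unavailable.
#
#     @cuda_only
#     def test_runs_on_gpu() -> None:
#         assert torch.cuda.is_available()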


def disable_tf32(fn):
    """Disable TF32 matmuls while ``fn`` runs, restoring the previous flags."""

    @wraps(fn)
    def wrapped(*args, **kwargs):
        # Save the current TF32 settings so they can be restored afterwards.
        cuda, cudnn = (
            torch.backends.cuda.matmul.allow_tf32,
            torch.backends.cudnn.allow_tf32,
        )
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore the flags even if ``fn`` raised.
            torch.backends.cuda.matmul.allow_tf32 = cuda
            torch.backends.cudnn.allow_tf32 = cudnn

    return wrapped


ref_attention_for_test = disable_tf32(ref_attention)
ref_attention_bmhk_for_test = disable_tf32(ref_attention_bmhk)
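
# A minimal usage sketch (hypothetical function, for illustration): the same
# decorator can wrap any matmul-heavy reference, so it runs in full float32
# precision regardless of the global TF32 flags, which are restored afterwards.
#
#     @disable_tf32
#     def exact_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
#         return a @ b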


def assert_allclose(
    out: Optional[torch.Tensor],
    ref: Optional[torch.Tensor],
    msg: str = "failed",
    atol: float = 1e-8,
    rtol: float = 1e-5,
) -> None:
    assert out is not None, f"{msg}: output Tensor is None"
    assert ref is not None, f"{msg}: reference Tensor is None"
    assert out.shape == ref.shape, f"Shape: {out.shape} (expected: {ref.shape})"
    assert out.dtype == ref.dtype, f"out dtype: {out.dtype}, ref dtype: {ref.dtype}"
    if out.numel() == 0:
        return
    # Per-element violation of |out - ref| <= atol + rtol * |ref|; positive
    # entries are exactly the elements that fail the tolerance check.
    flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
    max_pos = flatten_diff.argmax()
    max_location = np.unravel_index(int(max_pos), out.shape)
    max_diff = flatten_diff[max_pos]
    num_different = flatten_diff.numel() - torch.count_nonzero(flatten_diff <= 0)
    percentage = num_different / flatten_diff.numel()
    del flatten_diff
    assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
        f"{msg}: "
        f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
        f" at {max_location} of shape {tuple(out.shape)} / atol={atol}, rtol={rtol}"
        f" / total failing elements: {num_different} ({percentage * 100:.3}%)"
    )
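
# A minimal usage sketch (hypothetical, for illustration only): unlike a bare
# torch.allclose, assert_allclose reports the worst element, its location, and
# the fraction of failing elements, which makes large-tensor mismatches easier
# to debug.
#
#     a = torch.zeros(2, 3)
#     b = a.clone()
#     assert_allclose(a, b, msg="identical")  # passes
#     b[1, 2] = 1.0
#     # assert_allclose(a, b, msg="demo")  # would raise, pointing at (1, 2)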


def pack_kv_cache(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    kv_seqlens: List[int],
    BLOCK_N: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create block tables and a paged K/V cache for testing paged attention.
    Args:
        cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D].
            Note that these tensors are unexpanded,
            i.e. for the multiquery case cache_k.shape[2] == 1.
        kv_seqlens: list of K/V sequence lengths; the batch size B is
            inferred from cache_k.shape[0]
        BLOCK_N: number of tokens per paged attention block
    Returns:
        block_tables: [B, MAX_BLOCKS]
        packed_cache_k: [1, total_len_rounded, H_kv, D]
        packed_cache_v: [1, total_len_rounded, H_kv, D]
        where total_len_rounded is the sum of the K/V seqlens, each rounded up
        to a multiple of BLOCK_N.
    """
    # Round each sequence length up to a whole number of blocks.
    kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens]
    total_len_rounded = sum(kv_seqlens_rounded)
    B, MAX_T, H, D = cache_k.shape
    packed_cache_k = torch.empty(
        total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
    )
    packed_cache_v = torch.empty(
        total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
    )
    seqstart = 0
    for b in range(B):
        # Copy the valid tokens of each sequence to its block-aligned offset;
        # the padding up to the block boundary is left uninitialized.
        packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[
            b, : kv_seqlens[b]
        ].clone()
        packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[
            b, : kv_seqlens[b]
        ].clone()
        seqstart += kv_seqlens_rounded[b]
    num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N
    block_tables = (
        torch.arange(num_blocks_per_row, device=cache_k.device, dtype=torch.int32)
        .unsqueeze(0)
        .expand(B, num_blocks_per_row)
    )
    # Offset each row by the index of the first block of that sequence.
    seqstarts = (
        (
            torch.tensor(kv_seqlens_rounded).cumsum(dim=0)
            - torch.tensor(kv_seqlens_rounded)
        )
        .to(device=cache_k.device)
        .unsqueeze(1)
    ) // BLOCK_N
    block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32)
    return (
        block_tables,
        packed_cache_k.unsqueeze(0),
        packed_cache_v.unsqueeze(0),
    )
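
# A minimal usage sketch (hypothetical values, assuming a CUDA device): with
# BLOCK_N=16, seqlens [5, 20] round up to [16, 32], so the packed cache holds
# 48 tokens and sequence 1's first block sits at packed block index 1.
#
#     B, MAX_T, H_kv, D = 2, 32, 1, 8
#     cache_k = torch.randn(B, MAX_T, H_kv, D, device="cuda", dtype=torch.float16)
#     cache_v = torch.randn_like(cache_k)
#     block_tables, packed_k, packed_v = pack_kv_cache(
#         cache_k, cache_v, kv_seqlens=[5, 20], BLOCK_N=16
#     )
#     assert packed_k.shape == (1, 48, H_kv, D)
#     assert block_tables.tolist() == [[0, 1], [1, 2]]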