# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from functools import wraps
from typing import List, Optional, Tuple

import numpy as np
import pytest
import torch

from xformers.attn_bias_utils import ref_attention, ref_attention_bmhk

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
rocm_only = pytest.mark.skipif(
    not torch.cuda.is_available() or not torch.version.hip, reason="requires ROCM"
)
disable_on_rocm = pytest.mark.skipif(
    not not torch.version.hip, reason="could not be done on ROCM"
)
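# Usage sketch (hypothetical test, not part of the upstream file): the markers
# above are applied as decorators so that a test is skipped when the required
# backend is unavailable, for example:
#
#     @cuda_only
#     def test_runs_only_with_cuda() -> None:
#         assert torch.cuda.is_available()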
def disable_tf32(fn):
    """Run `fn` with TF32 disabled for CUDA matmuls and cuDNN, restoring the
    previous settings afterwards, so reference computations stay in full fp32."""

    @wraps(fn)
    def wrapped(*args, **kwargs):
        cuda, cudnn = (
            torch.backends.cuda.matmul.allow_tf32,
            torch.backends.cudnn.allow_tf32,
        )
        torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32 = (
            False,
            False,
        )
        try:
            return fn(*args, **kwargs)
        finally:
            torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32 = (
                cuda,
                cudnn,
            )

    return wrapped


ref_attention_for_test = disable_tf32(ref_attention)
ref_attention_bmhk_for_test = disable_tf32(ref_attention_bmhk)
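# Illustrative sketch (hypothetical helper, not part of the upstream file):
# disable_tf32 can also be applied as a decorator, so that a matmul-based
# reference always runs in strict fp32 even if the surrounding process has
# TF32 enabled globally.
@disable_tf32
def _example_fp32_matmul_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Plain matmul; with TF32 disabled this yields a full-precision fp32 result
    # suitable as a reference for lower-precision kernels.
    return a @ b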
def assert_allclose(
    out: Optional[torch.Tensor],
    ref: Optional[torch.Tensor],
    msg: str = "failed",
    atol: float = 1e-8,
    rtol: float = 1e-5,
) -> None:
    """Assert that `out` matches `ref` within `atol`/`rtol`, reporting the
    worst-offending element and the fraction of failing elements on mismatch."""
    assert out is not None, f"{msg}: output Tensor is None"
    assert ref is not None, f"{msg}: reference Tensor is None"
    assert out.shape == ref.shape, f"Shape: {out.shape} (expected: {ref.shape})"
    if out.dtype != ref.dtype:
        assert False, f"out dtype: {out.dtype}, ref dtype: {ref.dtype}"
    if out.numel() == 0:
        return
    flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
    max_pos = flatten_diff.argmax()
    max_location = np.unravel_index(int(max_pos), out.shape)
    max_diff = flatten_diff[max_pos]
    num_different = flatten_diff.numel() - torch.count_nonzero(flatten_diff <= 0)
    percentage = num_different / flatten_diff.numel()
    del flatten_diff
    assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
        f"{msg}: "
        f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)"
        f" at {max_location} of shape {tuple(out.shape)} / atol={atol}, rtol={rtol}"
        f"/ total failing elements: {num_different} ({percentage*100:.3}%)"
    )
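# Usage sketch (hypothetical, not part of the upstream file): compare a
# simulated low-precision output against an fp32 reference with tolerances
# loose enough for fp16 rounding.
def _example_assert_allclose_usage() -> None:
    ref = torch.randn(2, 8, 4, dtype=torch.float32)
    out = ref.to(torch.float16).to(torch.float32)  # emulate an fp16 kernel output
    assert_allclose(out, ref, msg="fp16 round-trip", atol=1e-2, rtol=1e-2)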
def pack_kv_cache(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    kv_seqlens: List[int],
    BLOCK_N: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create block tables and a paged K/V cache for testing paged attention.
    Args:
        cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D].
            Note that these tensors are unexpanded,
            i.e. for the multiquery case cache_k.shape[2] = 1.
            The batch size B is taken from cache_k.shape[0].
        kv_seqlens: list of K/V sequence lengths
        BLOCK_N: number of tokens per paged attention block
    Returns:
        block_tables: [B, MAX_BLOCKS]
        packed_cache_k: [1, total_len_rounded, H_kv, D]
        packed_cache_v: [1, total_len_rounded, H_kv, D]
        where total_len_rounded is the sum of the K/V seqlens, each rounded up
        to a multiple of BLOCK_N.
    """
    kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens]
    total_len_rounded = sum(kv_seqlens_rounded)
    B, MAX_T, H, D = cache_k.shape
    packed_cache_k = torch.empty(
        total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
    )
    packed_cache_v = torch.empty(
        total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype
    )
    seqstart = 0
    for b in range(B):
        packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[
            b, : kv_seqlens[b]
        ].clone()
        packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[
            b, : kv_seqlens[b]
        ].clone()
        seqstart += kv_seqlens_rounded[b]
    num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N
    block_tables = (
        torch.arange(num_blocks_per_row, device="cuda", dtype=torch.int32)
        .unsqueeze(0)
        .expand(B, num_blocks_per_row)
    )
    seqstarts = (
        (
            torch.tensor(kv_seqlens_rounded).cumsum(dim=0)
            - torch.tensor(kv_seqlens_rounded)
        )
        .to(device="cuda")
        .unsqueeze(1)
    ) // BLOCK_N
    block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32)
    return (
        block_tables,
        packed_cache_k.unsqueeze(0),
        packed_cache_v.unsqueeze(0),
    )
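# Usage sketch (hypothetical, not part of the upstream file; needs a CUDA device
# because block_tables is built on "cuda"): pack two variable-length sequences
# with 16-token blocks. kv_seqlens=[5, 20] round up to [16, 32], so the packed
# caches hold 48 tokens and the second sequence starts at block index 1.
def _example_pack_kv_cache_usage() -> None:
    B, MAX_T, H_kv, D = 2, 32, 1, 8
    cache_k = torch.randn(B, MAX_T, H_kv, D, device="cuda")
    cache_v = torch.randn(B, MAX_T, H_kv, D, device="cuda")
    block_tables, packed_k, packed_v = pack_kv_cache(
        cache_k, cache_v, kv_seqlens=[5, 20], BLOCK_N=16
    )
    assert packed_k.shape == (1, 48, H_kv, D)
    assert packed_v.shape == (1, 48, H_kv, D)
    assert block_tables.tolist() == [[0, 1], [1, 2]]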