File size: 4,820 Bytes
e202b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from functools import wraps
from typing import List, Optional, Tuple

import numpy as np
import pytest
import torch

from xformers.attn_bias_utils import ref_attention, ref_attention_bmhk

# Pytest markers used to gate tests on the GPU backend torch was built for.

# Skip unless torch reports a usable CUDA device.
cuda_only = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="requires CUDA",
)
# Skip unless a device is available AND torch.version.hip is set.
rocm_only = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch.version.hip),
    reason="requires ROCM",
)
# Skip whenever torch.version.hip is set.
disable_on_rocm = pytest.mark.skipif(
    bool(torch.version.hip),
    reason="could not be done on ROCM",
)


def disable_tf32(fn):
    """
    Decorator that runs ``fn`` with TF32 turned off.

    Both ``torch.backends.cuda.matmul.allow_tf32`` and
    ``torch.backends.cudnn.allow_tf32`` are set to ``False`` for the duration
    of the call. Their previous values are put back afterwards, including
    when ``fn`` raises.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        matmul_backend = torch.backends.cuda.matmul
        cudnn_backend = torch.backends.cudnn

        # Remember the caller's settings before overriding them.
        saved_matmul = matmul_backend.allow_tf32
        saved_cudnn = cudnn_backend.allow_tf32

        matmul_backend.allow_tf32 = False
        cudnn_backend.allow_tf32 = False
        try:
            return fn(*args, **kwargs)
        finally:
            # Always restore, even if fn raised.
            matmul_backend.allow_tf32 = saved_matmul
            cudnn_backend.allow_tf32 = saved_cudnn

    return wrapper


# Reference attention implementations wrapped with disable_tf32, so TF32 is
# switched off while they run.
ref_attention_for_test = disable_tf32(ref_attention)
ref_attention_bmhk_for_test = disable_tf32(ref_attention_bmhk)


def assert_allclose(
    out: Optional[torch.Tensor],
    ref: Optional[torch.Tensor],
    msg: str = "failed",
    atol: float = 1e-8,
    rtol: float = 1e-5,
) -> None:
    """
    Assert that ``out`` matches ``ref`` elementwise within tolerance.

    The check itself is ``torch.allclose(out, ref, rtol=rtol, atol=atol)``.
    On failure the assertion message reports the element with the largest
    tolerance excess ``|out - ref| - atol - rtol * |ref|``, its location, and
    how many elements exceed the tolerance.

    Args:
        out: tensor under test; must not be None.
        ref: reference tensor; must not be None and must have the same shape
            and dtype as ``out``.
        msg: prefix for the assertion message.
        atol: absolute tolerance.
        rtol: relative tolerance, scaled by ``|ref|``.

    Raises:
        AssertionError: if either tensor is None, if shapes or dtypes differ,
            or if the values are not close. Empty tensors always pass.
    """
    assert out is not None, f"{msg}: output Tensor is None"
    assert ref is not None, f"{msg}: reference Tensor is None"
    assert out.shape == ref.shape, f"Shape: {out.shape} (expected: {ref.shape})"
    assert out.dtype == ref.dtype, f"out dtype: {out.dtype}, ref dtype: {ref.dtype}"
    if not out.numel():
        return

    # Amount by which each element exceeds its tolerance; > 0 means failing.
    excess = (torch.abs(out - ref) - atol - torch.abs(ref) * rtol).flatten()
    worst = excess.argmax()
    worst_index = np.unravel_index(int(worst), out.shape)
    worst_excess = excess[worst]
    total = excess.numel()
    num_failing = total - torch.count_nonzero(excess <= 0)
    failing_fraction = num_failing / total
    # Release the temporary before the final comparison allocates its own.
    del excess

    assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
        f"{msg}: "
        f"out={out.flatten()[worst]} and ref={ref.flatten()[worst]} (diff={worst_excess} > 0)"
        f" at {worst_index} of shape {tuple(out.shape)} / atol={atol}, rtol={rtol}"
        f"/ total failing elements: {num_failing} ({failing_fraction*100:.3}%)"
    )


def pack_kv_cache(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    kv_seqlens: List[int],
    BLOCK_N: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create block tables and paged K/V caches for testing paged attention.

    Each sequence is copied into its own run of consecutive pages of BLOCK_N
    tokens; the runs are laid out back to back in batch order. The tail of a
    sequence's last page (the padding up to a multiple of BLOCK_N) is left
    uninitialized.

    Args:
        cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D],
            where B is the batch size.
            Note that these tensors are unexpanded,
            i.e. for multiquery case cache_k.shape[2] = 1
        kv_seqlens: list of K/V sequence lengths, one per batch element
        BLOCK_N: number of tokens per paged attention block
    Returns:
        block_tables: int32 tensor of shape [B, MAX_BLOCKS], where
            MAX_BLOCKS = ceil(MAX_T / BLOCK_N). Entry [b, i] is the index of
            sequence b's first page plus i, so entries past a sequence's own
            pages refer to later pages (or past the end of the packed cache).
        packed_cache_k: [1, total_len_rounded, H_kv, D]
        packed_cache_v: [1, total_len_rounded, H_kv, D]
    where total_len_rounded is a sum of K/V seqlens, each rounded up
    to a multiple of BLOCK_N. All returned tensors live on cache_k's device.
    Raises:
        ValueError: if len(kv_seqlens) does not match the batch size B.
    """
    B, MAX_T, H, D = cache_k.shape
    if len(kv_seqlens) != B:
        raise ValueError(
            f"kv_seqlens has {len(kv_seqlens)} entries, expected batch size {B}"
        )
    # Keep every output on the same device as the inputs instead of assuming
    # "cuda", so the helper also works for CPU or non-default-GPU caches.
    device = cache_k.device

    kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens]
    total_len_rounded = sum(kv_seqlens_rounded)

    packed_cache_k = torch.empty(
        total_len_rounded, H, D, device=device, dtype=cache_k.dtype
    )
    packed_cache_v = torch.empty_like(packed_cache_k)

    seqstart = 0
    for b, (seqlen, seqlen_rounded) in enumerate(zip(kv_seqlens, kv_seqlens_rounded)):
        # Slice assignment already copies, so no explicit clone is needed.
        packed_cache_k[seqstart : seqstart + seqlen] = cache_k[b, :seqlen]
        packed_cache_v[seqstart : seqstart + seqlen] = cache_v[b, :seqlen]
        # Advance by the rounded length so each sequence starts on a page boundary.
        seqstart += seqlen_rounded

    num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N
    block_offsets_in_row = torch.arange(
        num_blocks_per_row, device=device, dtype=torch.int32
    ).unsqueeze(0)

    # Index of the first page of each sequence: exclusive prefix sum of the
    # rounded lengths, expressed in pages.
    rounded = torch.tensor(kv_seqlens_rounded, dtype=torch.int64, device=device)
    first_block = ((rounded.cumsum(dim=0) - rounded) // BLOCK_N).unsqueeze(1)

    block_tables = (block_offsets_in_row + first_block).contiguous().to(
        dtype=torch.int32
    )
    return (
        block_tables,
        packed_cache_k.unsqueeze(0),
        packed_cache_v.unsqueeze(0),
    )