# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. from functools import wraps from typing import List, Optional, Tuple import numpy as np import pytest import torch from xformers.attn_bias_utils import ref_attention, ref_attention_bmhk cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") rocm_only = pytest.mark.skipif( not torch.cuda.is_available() or not torch.version.hip, reason="requires ROCM" ) disable_on_rocm = pytest.mark.skipif( not not torch.version.hip, reason="could not be done on ROCM" ) def disable_tf32(fn): @wraps(fn) def wrapped(*args, **kwargs): cuda, cudnn = ( torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32, ) torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32 = ( False, False, ) try: return fn(*args, **kwargs) finally: torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32 = ( cuda, cudnn, ) return wrapped ref_attention_for_test = disable_tf32(ref_attention) ref_attention_bmhk_for_test = disable_tf32(ref_attention_bmhk) def assert_allclose( out: Optional[torch.Tensor], ref: Optional[torch.Tensor], msg: str = "failed", atol: float = 1e-8, rtol: float = 1e-5, ) -> None: assert out is not None, f"{msg}: output Tensor is None" assert ref is not None, f"{msg}: reference Tensor is None" assert out.shape == ref.shape, f"Shape: {out.shape} (expected: {ref.shape})" if out.dtype != ref.dtype: assert False, f"out dtype: {out.dtype}, ref dtype: {ref.dtype}" if out.numel() == 0: return flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten() max_pos = flatten_diff.argmax() max_location = np.unravel_index(int(max_pos), out.shape) max_diff = flatten_diff[max_pos] num_different = flatten_diff.numel() - torch.count_nonzero(flatten_diff <= 0) percentage = num_different / flatten_diff.numel() del flatten_diff assert torch.allclose(out, ref, rtol=rtol, atol=atol), ( f"{msg}: " f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} (diff={max_diff} > 0)" f" at {max_location} of shape {tuple(out.shape)} / atol={atol}, rtol={rtol}" f"/ total failing elements: {num_different} ({percentage*100:.3}%)" ) def pack_kv_cache( cache_k: torch.Tensor, cache_v: torch.Tensor, kv_seqlens: List[int], BLOCK_N: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Create block tables and pages K/V cache for testing paged attention. Args: cache_k, cache_v: K/V caches, each of shape [B, MAX_T, H_kv, D]. Note that these tensors are unexpanded, i.e. for multiquery case cache_k.shape[2] = 1 kv_seqlens: list of K/V sequence lengths BLOCK_N: number of tokens per per paged attention block B: batch size Returns: block_tables: [B, MAX_BLOCKS] packed_cache_k: [1, total_len_rounded, H_kv, D] packed_cache_v: [1, total_len_rounded, H_kv, D] where total_len_rounded is a sum of K/V seqlens, each rounded up to a multiple of BLOCK_N. """ kv_seqlens_rounded = [(x + BLOCK_N - 1) // BLOCK_N * BLOCK_N for x in kv_seqlens] total_len_rounded = sum(kv_seqlens_rounded) B, MAX_T, H, D = cache_k.shape packed_cache_k = torch.empty( total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype ) packed_cache_v = torch.empty( total_len_rounded, H, D, device=cache_k.device, dtype=cache_k.dtype ) seqstart = 0 for b in range(B): packed_cache_k[seqstart : seqstart + kv_seqlens[b]] = cache_k[ b, : kv_seqlens[b] ].clone() packed_cache_v[seqstart : seqstart + kv_seqlens[b]] = cache_v[ b, : kv_seqlens[b] ].clone() seqstart += kv_seqlens_rounded[b] num_blocks_per_row = (MAX_T + BLOCK_N - 1) // BLOCK_N block_tables = ( torch.arange(num_blocks_per_row, device="cuda", dtype=torch.int32) .unsqueeze(0) .expand(B, num_blocks_per_row) ) seqstarts = ( ( torch.tensor(kv_seqlens_rounded).cumsum(dim=0) - torch.tensor(kv_seqlens_rounded) ) .to(device="cuda") .unsqueeze(1) ) // BLOCK_N block_tables = (block_tables + seqstarts).contiguous().to(dtype=torch.int32) return ( block_tables, packed_cache_k.unsqueeze(0), packed_cache_v.unsqueeze(0), )