Flexstorydiff / xformers /tests /test_attention_patterns.py
FlexTheAi's picture
Upload folder using huggingface_hub
e202b16 verified
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import pytest
import torch
import xformers.components.attention.attention_patterns as AP
from xformers.components.attention.sparsity_config import (
BigBirdSparsityConfig,
BSLongformerSparsityConfig,
DenseSparsityConfig,
FixedSparsityConfig,
VariableSparsityConfig,
)
# baseline implementations
def _local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
assert (
window_size % 2 == 1
), "The window size is assumed to be odd (counts self-attention + 2 wings)"
h_win_size = window_size // 2
attn_shape = (attn_size, attn_size)
full_attn = torch.ones(attn_shape, dtype=torch.bool)
mask = torch.tril(full_attn, diagonal=h_win_size)
mask &= ~torch.tril(full_attn, diagonal=-(h_win_size + 1))
return mask
def _generate_2d_grid(H, W):
i = torch.arange(H)
j = torch.arange(W)
i, j = torch.meshgrid(i, j)
return i, j
def _horizontal_axial_2d_distance(H, W, p=2.0):
i, _ = _generate_2d_grid(H, W)
ij = i.reshape(-1, 1).float()
d = torch.cdist(ij, ij, p=p)
return d
def _vertical_axial_2d_distance(H, W, p=2.0):
_, j = _generate_2d_grid(H, W)
ij = j.reshape(-1, 1).float()
d = torch.cdist(ij, ij, p=p)
return d
def _local_2d_distance(H, W, p=2.0):
# axial is a special case with p=0 and distance=2
i, j = _generate_2d_grid(H, W)
ij = torch.stack([i.flatten(), j.flatten()], 1).float()
d = torch.cdist(ij, ij, p=p)
return d
def _local_2d_gaussian_distribution(H, W, sigma=1.0):
d = _local_2d_distance(H, W, p=2.0) ** 2
d = torch.exp(-0.5 * sigma ** (-2.0) * d)
return d
@pytest.mark.parametrize("window_size", [3, 7, 11])
@pytest.mark.parametrize("attn_size", [50, 51, 64])
def test_local_1d_pattern(attn_size, window_size):
mask = AP.local_1d_pattern(attn_size, window_size).float()
mask_ref = _local_1d_pattern(attn_size, window_size).float()
assert torch.allclose(mask, mask_ref)
@pytest.mark.parametrize("p", [0, 1, 2])
@pytest.mark.parametrize("W", [5, 7, 10])
@pytest.mark.parametrize("H", [5, 7, 10])
def test_horizontal_axial_2d_distance(H, W, p):
d = AP.horizontal_axial_2d_distance(H, W, p=p)
d_ref = _horizontal_axial_2d_distance(H, W, p=p)
assert torch.allclose(d, d_ref)
@pytest.mark.parametrize("p", [0, 1, 2])
@pytest.mark.parametrize("W", [5, 7, 10])
@pytest.mark.parametrize("H", [5, 7, 10])
def test_vertical_axial_2d_distance(H, W, p):
d = AP.vertical_axial_2d_distance(H, W, p=p)
d_ref = _vertical_axial_2d_distance(H, W, p=p)
assert torch.allclose(d, d_ref)
@pytest.mark.parametrize("p", [0, 1, 2])
@pytest.mark.parametrize("W", [5, 7, 10])
@pytest.mark.parametrize("H", [5, 7, 10])
def test_local_2d_distance(H, W, p):
d = AP.local_2d_distance(H, W, p=p)
d_ref = _local_2d_distance(H, W, p=p)
assert torch.allclose(d, d_ref)
@pytest.mark.parametrize("sigma", [0.5, 1, 2])
@pytest.mark.parametrize("W", [5, 7, 10])
@pytest.mark.parametrize("H", [5, 7, 10])
def test_local_2d_gaussian_distribution(H, W, sigma):
d = AP.local_2d_gausian_distribution(H, W, sigma=sigma)
d_ref = _local_2d_gaussian_distribution(H, W, sigma=sigma)
assert torch.allclose(d, d_ref)
@pytest.mark.parametrize("window_size", [2, 4])
@pytest.mark.parametrize("W", [8, 16])
@pytest.mark.parametrize("H", [8, 16])
def test_swin_attention_pattern(H, W, window_size):
# test non-shifted case
d = AP.swin_attention_pattern(H, W, window_size, shift_size=0)
# partition the self-attention into regions of window_size
# similar to the window_partition function from the original paper
h = H // window_size
w = W // window_size
d = d.reshape(h, window_size, w, window_size, h, window_size, w, window_size)
product = itertools.product(range(h), range(w))
for y, x in product:
# every region should fully attend to itself
assert torch.all(d[y, :, x, :, y, :, x, :])
for y2, x2 in product:
if y == y2 or x == x2:
continue
# different regions shouldn't attend between each other
assert torch.all(~d[y, :, x, :, y2, :, x2, :])
# test shifted case
# in the shifted case, the self-attention should be the same
# as in the non-shifted case, when we pad the inputs, apply the operations and then
# remove the padding from the result
d_shifted = AP.swin_attention_pattern(
H, W, window_size, shift_size=window_size // 2
)
# add padding and remove shift
h = H + window_size
w = W + window_size
d_padded = AP.swin_attention_pattern(h, w, window_size, shift_size=0)
d_padded = d_padded.reshape(h, w, h, w)
# remove padding elements
half_size = window_size // 2
s = slice(half_size, -half_size)
d_padded = d_padded[s, s, s, s].reshape(H * W, H * W)
assert torch.all(d_padded == d_shifted)
@pytest.mark.parametrize("k", [2, 3])
@pytest.mark.parametrize("W", [8, 15])
@pytest.mark.parametrize("H", [8, 15])
def test_dilated_2d_pattern(H, W, k):
d = AP.dilated_2d_pattern(H, W, k)
d = d.reshape(H, W, H, W)
product_HW = itertools.product(range(H), range(W))
product_kk = itertools.product(range(k), range(k))
for h, w in product_HW:
i = h % k
j = w % k
# every kth element is taken
assert torch.all(d[h, w][i::k, j::k])
for ii, jj in product_kk:
if ii == i and jj == j:
continue
# and the other elements are discarded
assert torch.all(~d[h, w][ii::k, jj::k])
def test_pattern_to_layout():
BLOCK = 16
SIZE = 128
LAYOUT_SIZE = SIZE // BLOCK
# All ones
mask1 = torch.ones((SIZE, SIZE), dtype=torch.bool)
layout1 = AP.pattern_to_layout(mask1, BLOCK)
ref1 = torch.ones((LAYOUT_SIZE, LAYOUT_SIZE), dtype=torch.long)
assert torch.allclose(layout1, ref1)
# Diagonal -> expect block diagonal
mask2 = torch.eye(SIZE, dtype=torch.bool)
layout2 = AP.pattern_to_layout(mask2, BLOCK)
ref2 = torch.eye(LAYOUT_SIZE, dtype=torch.long)
assert torch.allclose(layout2, ref2)
# Lower triangular, without the diagonal
# note that the layout will need to have the diagonal, else the coefficients close enough would not be computed
mask3 = torch.tril(torch.ones((SIZE, SIZE)), diagonal=-1).to(torch.bool)
layout3 = AP.pattern_to_layout(mask3, BLOCK)
ref3 = torch.tril(torch.ones((LAYOUT_SIZE, LAYOUT_SIZE)), diagonal=0).to(torch.long)
assert torch.allclose(layout3, ref3)
# Handle heads properly
mask = torch.cat((mask1, mask2, mask3))
layout = AP.pattern_to_layout(mask, BLOCK)
assert torch.allclose(layout, torch.cat((ref1, ref2, ref3)))
# Catch problematic dimensions
mask_off = torch.ones((SIZE + 3, SIZE), dtype=torch.bool)
with pytest.raises(AssertionError):
AP.pattern_to_layout(mask_off, BLOCK)
def test_alibi_pattern():
mask = AP.alibi_pattern(1e-3, (16, 128, 128))
# Minor, check that all the top left corners are True
assert torch.sum(mask[:, 0, 0]) == 16
def test_quick_layouts():
seq_size = 128
block_size = 16
num_heads = 2
# Fixed
assert torch.allclose(
AP.quick_fixed_layout(num_heads, block_size, seq_size),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
],
[
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1],
],
]
).long(),
)
# BSLongformer
assert torch.allclose(
AP.quick_bslongformer_layout(num_heads, block_size, seq_size),
torch.Tensor(
[
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 1, 1, 1, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 1, 1],
],
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 1, 1, 1, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 1, 1],
],
]
).long(),
)
# Variable
assert torch.allclose(
AP.quick_variable_layout(num_heads, block_size, seq_size),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
],
[
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
],
]
).long(),
)
# BigBird (just the shape)
assert AP.quick_bigbird_layout(num_heads, block_size, seq_size).shape == torch.Size(
[num_heads, seq_size // block_size, seq_size // block_size]
)
def test_layout_to_pattern():
torch.allclose(
AP.layout_to_pattern(
layout=torch.Tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]]]), block_size=2
),
torch.Tensor(
[
[
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
],
[
[1.0, 1.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
],
]
),
)
def test_dense_sparsity_config():
sc = DenseSparsityConfig(num_heads=1, block_size=16)
with pytest.raises(expected_exception=ValueError):
sc.setup_layout(seq_len=17)
assert torch.allclose(
sc.make_layout(seq_len=32), torch.Tensor([[[1, 1], [1, 1]]]).long()
)
def test_big_bird_sparsity_config():
sc = BigBirdSparsityConfig(
num_heads=1,
block_size=16,
num_random_blocks=2,
num_sliding_window_blocks=1,
num_global_blocks=1,
)
with pytest.raises(expected_exception=ValueError):
sc.make_layout(seq_len=16)
sc = BigBirdSparsityConfig(
num_heads=1,
block_size=16,
num_random_blocks=1,
num_sliding_window_blocks=2,
num_global_blocks=1,
)
with pytest.raises(expected_exception=ValueError):
sc.make_layout(seq_len=16)
sc = BigBirdSparsityConfig(
num_heads=1,
block_size=16,
num_random_blocks=1,
num_sliding_window_blocks=1,
num_global_blocks=2,
)
with pytest.raises(expected_exception=ValueError):
sc.make_layout(seq_len=16)
with pytest.raises(expected_exception=NotImplementedError):
BigBirdSparsityConfig(num_heads=1, attention="directional")
def test_bslongformer_sparsity_config():
sc = BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[1])
assert torch.allclose(
sc.make_layout(128),
torch.Tensor(
[
[
[1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 1, 1, 1, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 0],
[1, 0, 0, 0, 0, 1, 1, 1],
[1, 0, 0, 0, 0, 0, 1, 1],
]
]
).long(),
)
with pytest.raises(expected_exception=ValueError):
BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[])
with pytest.raises(expected_exception=ValueError):
BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[-1])
def test_fixed_sparsity_config():
# chech that the case end < num_blocks is correct
sc = FixedSparsityConfig(num_heads=1, horizontal_global_attention=True)
assert torch.allclose(
sc.make_layout(112),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 1],
[1, 1, 1, 1, 0, 0, 1],
[1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1],
]
]
).long(),
)
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(num_heads=1, num_local_blocks=3, num_global_blocks=2)
with pytest.raises(expected_exception=NotImplementedError):
FixedSparsityConfig(num_heads=1, attention="directional")
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(
num_heads=1, attention="unidirectional", horizontal_global_attention=True
)
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(
num_heads=1,
num_different_global_patterns=2,
different_layout_per_head=False,
)
with pytest.raises(expected_exception=ValueError):
FixedSparsityConfig(
num_heads=1,
num_different_global_patterns=10,
num_local_blocks=4,
num_global_blocks=1,
)
def test_variable_sparsity_config():
sc = VariableSparsityConfig(num_heads=1, global_block_end_indices=[1])
assert torch.allclose(
sc.make_layout(128),
torch.Tensor(
[
[
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
[1, 0, 0, 0, 1, 1, 1, 1],
]
]
).long(),
)
with pytest.raises(expected_exception=ValueError):
VariableSparsityConfig(num_heads=1, global_block_end_indices=[])
with pytest.raises(expected_exception=ValueError):
VariableSparsityConfig(num_heads=1, global_block_end_indices=[-1])