# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import pytest
import torch

import xformers.components.attention.attention_patterns as AP
from xformers.components.attention.sparsity_config import (
    BigBirdSparsityConfig,
    BSLongformerSparsityConfig,
    DenseSparsityConfig,
    FixedSparsityConfig,
    VariableSparsityConfig,
)
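
# This module exercises two families of helpers:
#   * xformers.components.attention.attention_patterns (imported as AP): dense
#     boolean attention masks, 2D distance helpers and block-sparse layouts.
#   * xformers.components.attention.sparsity_config: DeepSpeed-style sparsity
#     configurations (Dense, Fixed, Variable, BigBird, BSLongformer).
# The underscore-prefixed functions below are small reference (baseline)
# implementations that the public AP helpers are checked against.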


# baseline implementations


def _local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
    assert (
        window_size % 2 == 1
    ), "The window size is assumed to be odd (counts self-attention + 2 wings)"
    h_win_size = window_size // 2
    attn_shape = (attn_size, attn_size)
    full_attn = torch.ones(attn_shape, dtype=torch.bool)
    mask = torch.tril(full_attn, diagonal=h_win_size)
    mask &= ~torch.tril(full_attn, diagonal=-(h_win_size + 1))
    return mask
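

# Worked example: _local_1d_pattern(4, 3) keeps a band of width 3 around the
# diagonal (self-attention plus one neighbour on each side):
#   [[1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [0, 1, 1, 1],
#    [0, 0, 1, 1]]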


def _generate_2d_grid(H, W):
    i = torch.arange(H)
    j = torch.arange(W)
    i, j = torch.meshgrid(i, j)
    return i, j


def _horizontal_axial_2d_distance(H, W, p=2.0):
    i, _ = _generate_2d_grid(H, W)
    ij = i.reshape(-1, 1).float()
    d = torch.cdist(ij, ij, p=p)
    return d


def _vertical_axial_2d_distance(H, W, p=2.0):
    _, j = _generate_2d_grid(H, W)
    ij = j.reshape(-1, 1).float()
    d = torch.cdist(ij, ij, p=p)
    return d


def _local_2d_distance(H, W, p=2.0):
    # axial is a special case with p=0 and distance=2
    i, j = _generate_2d_grid(H, W)
    ij = torch.stack([i.flatten(), j.flatten()], 1).float()
    d = torch.cdist(ij, ij, p=p)
    return d


def _local_2d_gaussian_distribution(H, W, sigma=1.0):
    d = _local_2d_distance(H, W, p=2.0) ** 2
    d = torch.exp(-0.5 * sigma ** (-2.0) * d)
    return d
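

# The tests below validate the public AP helpers against the baseline
# implementations above. NOTE: the pytest parametrization grids used below are
# representative values, not canonical ones; any sizes satisfying the stated
# constraints (odd window sizes, H and W divisible by the window / dilation)
# would work as well.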


@pytest.mark.parametrize("window_size", [3, 5, 7])
@pytest.mark.parametrize("attn_size", [32, 64])
def test_local_1d_pattern(attn_size, window_size):
    mask = AP.local_1d_pattern(attn_size, window_size).float()
    mask_ref = _local_1d_pattern(attn_size, window_size).float()
    assert torch.allclose(mask, mask_ref)


@pytest.mark.parametrize("p", [1.0, 2.0])
@pytest.mark.parametrize("W", [5, 8])
@pytest.mark.parametrize("H", [5, 8])
def test_horizontal_axial_2d_distance(H, W, p):
    d = AP.horizontal_axial_2d_distance(H, W, p=p)
    d_ref = _horizontal_axial_2d_distance(H, W, p=p)
    assert torch.allclose(d, d_ref)


@pytest.mark.parametrize("p", [1.0, 2.0])
@pytest.mark.parametrize("W", [5, 8])
@pytest.mark.parametrize("H", [5, 8])
def test_vertical_axial_2d_distance(H, W, p):
    d = AP.vertical_axial_2d_distance(H, W, p=p)
    d_ref = _vertical_axial_2d_distance(H, W, p=p)
    assert torch.allclose(d, d_ref)


@pytest.mark.parametrize("p", [1.0, 2.0])
@pytest.mark.parametrize("W", [5, 8])
@pytest.mark.parametrize("H", [5, 8])
def test_local_2d_distance(H, W, p):
    d = AP.local_2d_distance(H, W, p=p)
    d_ref = _local_2d_distance(H, W, p=p)
    assert torch.allclose(d, d_ref)


@pytest.mark.parametrize("sigma", [0.5, 1.0, 2.0])
@pytest.mark.parametrize("W", [5, 8])
@pytest.mark.parametrize("H", [5, 8])
def test_local_2d_gaussian_distribution(H, W, sigma):
    # NB: the AP helper is spelled "gausian"
    d = AP.local_2d_gausian_distribution(H, W, sigma=sigma)
    d_ref = _local_2d_gaussian_distribution(H, W, sigma=sigma)
    assert torch.allclose(d, d_ref)


@pytest.mark.parametrize("window_size", [2, 4])
@pytest.mark.parametrize("W", [8, 16])
@pytest.mark.parametrize("H", [8, 16])
def test_swin_attention_pattern(H, W, window_size):
    # H and W are chosen divisible by window_size (illustrative values)
    # test non-shifted case
    d = AP.swin_attention_pattern(H, W, window_size, shift_size=0)

    # partition the self-attention into regions of window_size,
    # similar to the window_partition function from the original paper
    h = H // window_size
    w = W // window_size
    d = d.reshape(h, window_size, w, window_size, h, window_size, w, window_size)

    # materialize the product so that it can be iterated over more than once
    product = list(itertools.product(range(h), range(w)))
    for y, x in product:
        # every region should fully attend to itself
        assert torch.all(d[y, :, x, :, y, :, x, :])
        for y2, x2 in product:
            if y == y2 or x == x2:
                continue
            # different regions shouldn't attend to each other
            assert torch.all(~d[y, :, x, :, y2, :, x2, :])

    # test shifted case
    # in the shifted case, the self-attention should be the same as in the
    # non-shifted case once we pad the inputs, apply the operation and then
    # remove the padding from the result
    d_shifted = AP.swin_attention_pattern(
        H, W, window_size, shift_size=window_size // 2
    )

    # add padding and remove the shift
    h = H + window_size
    w = W + window_size
    d_padded = AP.swin_attention_pattern(h, w, window_size, shift_size=0)
    d_padded = d_padded.reshape(h, w, h, w)

    # remove padding elements
    half_size = window_size // 2
    s = slice(half_size, -half_size)
    d_padded = d_padded[s, s, s, s].reshape(H * W, H * W)

    assert torch.all(d_padded == d_shifted)
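

# AP.dilated_2d_pattern keeps, for each query position, every k-th key along
# both spatial dimensions, offset by the query's coordinates modulo k; the
# checks below mirror that definition directly.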


@pytest.mark.parametrize("k", [2, 4])
@pytest.mark.parametrize("W", [8, 16])
@pytest.mark.parametrize("H", [8, 16])
def test_dilated_2d_pattern(H, W, k):
    # H and W are multiples of k (illustrative values)
    d = AP.dilated_2d_pattern(H, W, k)
    d = d.reshape(H, W, H, W)

    # materialize the products so that they can be iterated over more than once
    product_HW = list(itertools.product(range(H), range(W)))
    product_kk = list(itertools.product(range(k), range(k)))
    for h, w in product_HW:
        i = h % k
        j = w % k
        # every k-th element is kept
        assert torch.all(d[h, w][i::k, j::k])
        for ii, jj in product_kk:
            if ii == i and jj == j:
                continue
            # and the other elements are discarded
            assert torch.all(~d[h, w][ii::k, jj::k])
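

# AP.pattern_to_layout coarsens a dense boolean mask into a block layout: a
# layout entry is set as soon as any element of the corresponding
# block_size x block_size tile of the mask is set (see the tril case below,
# where the diagonal blocks reappear in the layout). Tiny example with a block
# size of 2 and a 4x4 mask whose only True entries sit in the top-left tile:
#   mask = [[1, 0, 0, 0],        layout = [[1, 0],
#           [0, 1, 0, 0],                  [0, 0]]
#           [0, 0, 0, 0],
#           [0, 0, 0, 0]]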


def test_pattern_to_layout():
    BLOCK = 16
    SIZE = 128
    LAYOUT_SIZE = SIZE // BLOCK

    # All ones
    mask1 = torch.ones((SIZE, SIZE), dtype=torch.bool)
    layout1 = AP.pattern_to_layout(mask1, BLOCK)
    ref1 = torch.ones((LAYOUT_SIZE, LAYOUT_SIZE), dtype=torch.long)
    assert torch.allclose(layout1, ref1)

    # Diagonal -> expect block diagonal
    mask2 = torch.eye(SIZE, dtype=torch.bool)
    layout2 = AP.pattern_to_layout(mask2, BLOCK)
    ref2 = torch.eye(LAYOUT_SIZE, dtype=torch.long)
    assert torch.allclose(layout2, ref2)

    # Lower triangular, without the diagonal
    # note that the layout still needs the diagonal blocks, otherwise the
    # coefficients close to the diagonal would never be computed
    mask3 = torch.tril(torch.ones((SIZE, SIZE)), diagonal=-1).to(torch.bool)
    layout3 = AP.pattern_to_layout(mask3, BLOCK)
    ref3 = torch.tril(torch.ones((LAYOUT_SIZE, LAYOUT_SIZE)), diagonal=0).to(torch.long)
    assert torch.allclose(layout3, ref3)

    # Handle heads properly
    mask = torch.cat((mask1, mask2, mask3))
    layout = AP.pattern_to_layout(mask, BLOCK)
    assert torch.allclose(layout, torch.cat((ref1, ref2, ref3)))

    # Catch problematic dimensions
    mask_off = torch.ones((SIZE + 3, SIZE), dtype=torch.bool)
    with pytest.raises(AssertionError):
        AP.pattern_to_layout(mask_off, BLOCK)


def test_alibi_pattern():
    mask = AP.alibi_pattern(1e-3, (16, 128, 128))
    # sanity check: the top-left corner should be True for every head
    assert torch.sum(mask[:, 0, 0]) == 16
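

# The quick_*_layout helpers produce block layouts of shape
# [num_heads, seq_size // block_size, seq_size // block_size]; note that the
# expected BSLongformer and Variable layouts here match the corresponding
# sparsity-config tests further down.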


def test_quick_layouts():
    seq_size = 128
    block_size = 16
    num_heads = 2

    # Fixed
    assert torch.allclose(
        AP.quick_fixed_layout(num_heads, block_size, seq_size),
        torch.Tensor(
            [
                [
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                ],
                [
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 0, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1, 1],
                ],
            ]
        ).long(),
    )

    # BSLongformer
    assert torch.allclose(
        AP.quick_bslongformer_layout(num_heads, block_size, seq_size),
        torch.Tensor(
            [
                [
                    [1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 0, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 0, 1, 1, 1, 0, 0, 0],
                    [1, 0, 0, 1, 1, 1, 0, 0],
                    [1, 0, 0, 0, 1, 1, 1, 0],
                    [1, 0, 0, 0, 0, 1, 1, 1],
                    [1, 0, 0, 0, 0, 0, 1, 1],
                ],
                [
                    [1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 0, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 0, 1, 1, 1, 0, 0, 0],
                    [1, 0, 0, 1, 1, 1, 0, 0],
                    [1, 0, 0, 0, 1, 1, 1, 0],
                    [1, 0, 0, 0, 0, 1, 1, 1],
                    [1, 0, 0, 0, 0, 0, 1, 1],
                ],
            ]
        ).long(),
    )

    # Variable
    assert torch.allclose(
        AP.quick_variable_layout(num_heads, block_size, seq_size),
        torch.Tensor(
            [
                [
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                ],
                [
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                ],
            ]
        ).long(),
    )

    # BigBird (just the shape, the layout is partly random)
    assert AP.quick_bigbird_layout(num_heads, block_size, seq_size).shape == torch.Size(
        [num_heads, seq_size // block_size, seq_size // block_size]
    )
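

# AP.layout_to_pattern goes the other way: every entry of the layout is
# expanded into a block_size x block_size tile, so a 2x2 layout with
# block_size=2 becomes a 4x4 pattern, as spelled out in the reference tensor
# below.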


def test_layout_to_pattern():
    assert torch.allclose(
        AP.layout_to_pattern(
            layout=torch.Tensor([[[0, 1], [1, 0]], [[1, 0], [0, 1]]]), block_size=2
        ),
        torch.Tensor(
            [
                [
                    [0.0, 0.0, 1.0, 1.0],
                    [0.0, 0.0, 1.0, 1.0],
                    [1.0, 1.0, 0.0, 0.0],
                    [1.0, 1.0, 0.0, 0.0],
                ],
                [
                    [1.0, 1.0, 0.0, 0.0],
                    [1.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 1.0],
                    [0.0, 0.0, 1.0, 1.0],
                ],
            ]
        ),
    )
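

# DeepSpeed-style sparsity configs: each config builds a per-head block layout
# through make_layout(seq_len); invalid parameter combinations are expected to
# raise ValueError (or NotImplementedError for unsupported attention modes).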


def test_dense_sparsity_config():
    sc = DenseSparsityConfig(num_heads=1, block_size=16)
    with pytest.raises(expected_exception=ValueError):
        sc.setup_layout(seq_len=17)
    assert torch.allclose(
        sc.make_layout(seq_len=32), torch.Tensor([[[1, 1], [1, 1]]]).long()
    )


def test_big_bird_sparsity_config():
    sc = BigBirdSparsityConfig(
        num_heads=1,
        block_size=16,
        num_random_blocks=2,
        num_sliding_window_blocks=1,
        num_global_blocks=1,
    )
    with pytest.raises(expected_exception=ValueError):
        sc.make_layout(seq_len=16)

    sc = BigBirdSparsityConfig(
        num_heads=1,
        block_size=16,
        num_random_blocks=1,
        num_sliding_window_blocks=2,
        num_global_blocks=1,
    )
    with pytest.raises(expected_exception=ValueError):
        sc.make_layout(seq_len=16)

    sc = BigBirdSparsityConfig(
        num_heads=1,
        block_size=16,
        num_random_blocks=1,
        num_sliding_window_blocks=1,
        num_global_blocks=2,
    )
    with pytest.raises(expected_exception=ValueError):
        sc.make_layout(seq_len=16)

    with pytest.raises(expected_exception=NotImplementedError):
        BigBirdSparsityConfig(num_heads=1, attention="directional")


def test_bslongformer_sparsity_config():
    sc = BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[1])
    assert torch.allclose(
        sc.make_layout(128),
        torch.Tensor(
            [
                [
                    [1, 1, 1, 1, 1, 1, 1, 1],
                    [1, 1, 1, 0, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 0, 1, 1, 1, 0, 0, 0],
                    [1, 0, 0, 1, 1, 1, 0, 0],
                    [1, 0, 0, 0, 1, 1, 1, 0],
                    [1, 0, 0, 0, 0, 1, 1, 1],
                    [1, 0, 0, 0, 0, 0, 1, 1],
                ]
            ]
        ).long(),
    )

    with pytest.raises(expected_exception=ValueError):
        BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[])

    with pytest.raises(expected_exception=ValueError):
        BSLongformerSparsityConfig(num_heads=1, global_block_end_indices=[-1])


def test_fixed_sparsity_config():
    # check that the case end < num_blocks is handled correctly
    sc = FixedSparsityConfig(num_heads=1, horizontal_global_attention=True)
    assert torch.allclose(
        sc.make_layout(112),
        torch.Tensor(
            [
                [
                    [1, 1, 1, 1, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 1],
                    [1, 1, 1, 1, 0, 0, 1],
                    [1, 1, 1, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1],
                    [0, 0, 0, 1, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1, 1],
                ]
            ]
        ).long(),
    )

    with pytest.raises(expected_exception=ValueError):
        FixedSparsityConfig(num_heads=1, num_local_blocks=3, num_global_blocks=2)

    with pytest.raises(expected_exception=NotImplementedError):
        FixedSparsityConfig(num_heads=1, attention="directional")

    with pytest.raises(expected_exception=ValueError):
        FixedSparsityConfig(
            num_heads=1, attention="unidirectional", horizontal_global_attention=True
        )

    with pytest.raises(expected_exception=ValueError):
        FixedSparsityConfig(
            num_heads=1,
            num_different_global_patterns=2,
            different_layout_per_head=False,
        )

    with pytest.raises(expected_exception=ValueError):
        FixedSparsityConfig(
            num_heads=1,
            num_different_global_patterns=10,
            num_local_blocks=4,
            num_global_blocks=1,
        )


def test_variable_sparsity_config():
    sc = VariableSparsityConfig(num_heads=1, global_block_end_indices=[1])
    assert torch.allclose(
        sc.make_layout(128),
        torch.Tensor(
            [
                [
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 1, 1, 1, 0, 0, 0, 0],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                    [1, 0, 0, 0, 1, 1, 1, 1],
                ]
            ]
        ).long(),
    )

    with pytest.raises(expected_exception=ValueError):
        VariableSparsityConfig(num_heads=1, global_block_end_indices=[])

    with pytest.raises(expected_exception=ValueError):
        VariableSparsityConfig(num_heads=1, global_block_end_indices=[-1])