# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest
import torch

from xformers.components.attention import NystromAttention, ScaledDotProduct
from xformers.components.attention.utils import maybe_merge_masks


@pytest.mark.parametrize("pinverse_original_init", [True, False])
@pytest.mark.parametrize("use_razavi_pinverse", [True, False])
@pytest.mark.parametrize("num_landmarks", [30, 33, 905])
def test_nystrom_attention_close_to_sdp(
    pinverse_original_init: bool,
    use_razavi_pinverse: bool,
    num_landmarks: int,
):
    # TODO: this test fails when the conv_kernel_size parameter is not set to None. Investigate.
    b, s, d = 2, 900, 40
    num_heads = 2
    seed = 42
    torch.random.manual_seed(seed)
    random.seed(seed)

    nystrom_config = {
        "name": "nystrom",
        "dropout": 0.0,
        "num_landmarks": num_landmarks,
        "num_heads": num_heads,
        "pinverse_original_init": pinverse_original_init,
        "use_razavi_pinverse": use_razavi_pinverse,
    }

    sdp_config = {
        "name": "scaled_dot_product",
        "dropout": 0.0,
    }

    a = torch.rand(b, s, d)

    def test_close_to_sdp():
        # Make sure that Nystrom and standard scaled dot-product attention
        # are not too far off.
        nystrom_attention = NystromAttention(**nystrom_config)
        sdp_attention = ScaledDotProduct(**sdp_config)

        r_nystrom = nystrom_attention(a, a, a, att_mask=None)
        r_sdp = sdp_attention(a, a, a, att_mask=None)
        assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2)

        # Run the same comparison a second time with freshly constructed
        # modules; the results should still match.
        nystrom_attention = NystromAttention(**nystrom_config)
        sdp_attention = ScaledDotProduct(**sdp_config)

        r_nystrom = nystrom_attention(a, a, a, att_mask=None)
        r_sdp = sdp_attention(a, a, a, att_mask=None)
        assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2)

    test_close_to_sdp()


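# A minimal sketch (added; not part of the original test): the config dicts
# above use xformers' registry format (a "name" key plus keyword arguments),
# so the same modules could presumably also be built by name via
# build_attention instead of instantiating the classes directly.
def _build_from_registry(nystrom_config: dict, sdp_config: dict):
    from xformers.components.attention import build_attention

    # Each call should return an attention module equivalent to the direct
    # constructions used in the tests above.
    return build_attention(nystrom_config), build_attention(sdp_config)

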
@pytest.mark.parametrize("pinverse_original_init", [True])
@pytest.mark.parametrize("use_razavi_pinverse", [True])
@pytest.mark.parametrize("num_landmarks", [30])
def test_nystrom_attention(
    pinverse_original_init: bool,
    use_razavi_pinverse: bool,
    num_landmarks: int,
):
    # TODO: this test fails when the conv_kernel_size parameter is not set to None. Investigate.
    b, s, d = 2, 900, 40
    num_heads = 2
    seed = 42
    torch.random.manual_seed(seed)
    random.seed(seed)

    nystrom_config = {
        "name": "nystrom",
        "dropout": 0.0,
        "num_landmarks": num_landmarks,
        "num_heads": num_heads,
        "pinverse_original_init": pinverse_original_init,
        "use_razavi_pinverse": use_razavi_pinverse,
    }

    sdp_config = {
        "name": "scaled_dot_product",
        "dropout": 0.0,
    }

    a = torch.rand(b, s, d)
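
    # NOTE (added): b appears to fold num_heads into the batch dimension, which
    # is why the maybe_merge_masks calls below pass batch_size=b // num_heads.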

    def test_att_mask_ignored():
        # If an s x s attention mask is passed in, it should be ignored, and
        # the results should match a run without any attention mask.
        nystrom_attention = NystromAttention(**nystrom_config)
        sdp_attention = ScaledDotProduct(**sdp_config)

        key_padding_mask = None
        att_mask = torch.randint(0, 2, (s, s)).to(dtype=torch.bool)
        sdp_mask = maybe_merge_masks(
            att_mask=None,
            key_padding_mask=key_padding_mask,
            batch_size=b // num_heads,
            src_len=s,
            num_heads=num_heads,
        )
        r_nystrom = nystrom_attention(
            a, a, a, att_mask=att_mask, key_padding_mask=key_padding_mask
        )
        r_sdp = sdp_attention(a, a, a, att_mask=sdp_mask)
        assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2)

    def test_masking():
        # FIXME: the causal configuration below is disabled for now.
        # nystrom_config["causal"] = True
        # sdp_config["causal"] = True
        nystrom_attention = NystromAttention(**nystrom_config)
        sdp_attention = ScaledDotProduct(**sdp_config)

        key_padding_mask = torch.rand((b // num_heads, s)) > 0.1
        att_mask = None
        mask = maybe_merge_masks(
            att_mask,
            key_padding_mask,
            batch_size=b // num_heads,
            src_len=s,
            num_heads=num_heads,
        )
        r_nystrom = nystrom_attention(a, a, a, key_padding_mask=key_padding_mask)
        r_sdp = sdp_attention(a, a, a, att_mask=mask)

        # The tolerance is deliberately loose: this mainly checks that masking
        # is functional, not that the approximation is tight.
        assert torch.allclose(
            r_nystrom, r_sdp, rtol=0.1, atol=0.5
        ), f"max diff {torch.max(torch.abs(r_nystrom - r_sdp))}"

        # An error is expected when the key padding mask does not have the
        # expected dimensions.
        key_padding_mask = torch.randint(0, 2, (s, b)).to(dtype=torch.bool)
        with pytest.raises(AssertionError):
            nystrom_attention(a, a, a, key_padding_mask=key_padding_mask)

    test_att_mask_ignored()
    test_masking()
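

# A minimal sketch (added): allow running this file directly, outside of a
# pytest invocation, e.g. `python test_nystrom_attention.py`.
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))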