# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import random

import pytest
import torch

from xformers.components.attention import NystromAttention, ScaledDotProduct
from xformers.components.attention.utils import maybe_merge_masks


@pytest.mark.parametrize("pinverse_original_init", [True, False])
@pytest.mark.parametrize("use_razavi_pinverse", [True, False])
@pytest.mark.parametrize("num_landmarks", [30, 33, 905])
def test_nystrom_attention_close_to_sdp(
    pinverse_original_init: bool,
    use_razavi_pinverse: bool,
    num_landmarks: int,
):
    # TODO: the test fails when the conv_kernel_size parameter is not None. Investigate.
    b, s, d = 2, 900, 40
    num_heads = 2
    seed = 42
    torch.random.manual_seed(seed)
    random.seed(seed)

    nystrom_config = {
        "name": "nystrom",
        "dropout": 0.0,
        "num_landmarks": num_landmarks,
        "num_heads": num_heads,
        "pinverse_original_init": pinverse_original_init,
        "use_razavi_pinverse": use_razavi_pinverse,
    }

    sdp_config = {
        "name": "scaled_dot_product",
        "dropout": 0.0,
    }

    a = torch.rand(b, s, d)

    def test_close_to_sdp():
        # Make sure that Nystrom and Normal attention are not too far off.

        nystrom_attention = NystromAttention(**nystrom_config)
        sdp_attention = ScaledDotProduct(**sdp_config)

        r_nystrom = nystrom_attention(a, a, a, att_mask=None)
        r_sdp = sdp_attention(a, a, a, att_mask=None)

        assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2)

    test_close_to_sdp()


@pytest.mark.parametrize("pinverse_original_init", [True])
@pytest.mark.parametrize("use_razavi_pinverse", [True])
@pytest.mark.parametrize("num_landmarks", [30])
def test_nystrom_attention(
    pinverse_original_init: bool,
    use_razavi_pinverse: bool,
    num_landmarks: int,
):
    # TODO: the test fails when the conv_kernel_size parameter is not None. Investigate.
    b, s, d = 2, 900, 40
    num_heads = 2
    seed = 42
    torch.random.manual_seed(seed)
    random.seed(seed)

    nystrom_config = {
        "name": "nystrom",
        "dropout": 0.0,
        "num_landmarks": num_landmarks,
        "num_heads": num_heads,
        "pinverse_original_init": pinverse_original_init,
        "use_razavi_pinverse": use_razavi_pinverse,
    }

    sdp_config = {
        "name": "scaled_dot_product",
        "dropout": 0.0,
    }

    a = torch.rand(b, s, d)

    def test_att_mask_ignored():
        # If an s x s attention mask is passed in, NystromAttention should ignore it:
        # the results should match the no-mask case.
        nystrom_attention = NystromAttention(**nystrom_config)
        sdp_attention = ScaledDotProduct(**sdp_config)

        key_padding_mask = None
        att_mask = torch.randint(0, 2, (s, s)).to(dtype=torch.bool)
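        # Reference mask for SDP, built as if no attention mask had been given
        # (NystromAttention is expected to drop the s x s mask).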
        sdp_mask = maybe_merge_masks(
            att_mask=None,
            key_padding_mask=key_padding_mask,
            batch_size=b // num_heads,
            src_len=s,
            num_heads=num_heads,
        )
        r_nystrom = nystrom_attention(
            a, a, a, att_mask=att_mask, key_padding_mask=key_padding_mask
        )
        r_sdp = sdp_attention(a, a, a, att_mask=sdp_mask)
        assert torch.allclose(r_nystrom, r_sdp, rtol=0.005, atol=1e-2)

    def test_masking():
        # FIXME
        # nystrom_config["causal"] = True
        # sdp_config["causal"] = True

        nystrom_attention = NystromAttention(**nystrom_config)
        sdp_attention = ScaledDotProduct(**sdp_config)

        # Random boolean key padding mask, with roughly 90% of the positions set to True.
        key_padding_mask = torch.rand((b // num_heads, s)) > 0.1
        att_mask = None
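        # Merge the key padding mask into a single attention mask for ScaledDotProduct.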
        mask = maybe_merge_masks(
            att_mask,
            key_padding_mask,
            batch_size=b // num_heads,
            src_len=s,
            num_heads=num_heads,
        )
        r_nystrom = nystrom_attention(a, a, a, key_padding_mask=key_padding_mask)
        r_sdp = sdp_attention(a, a, a, att_mask=mask)

        # The tolerances are loose: this mainly checks that masking works at all,
        # not that the two attentions agree closely.
        assert torch.allclose(
            r_nystrom, r_sdp, rtol=0.1, atol=0.5
        ), f"max diff {torch.max(torch.abs(r_nystrom-r_sdp))}"

        # A key padding mask with unexpected dimensions should raise an assertion error.
        key_padding_mask = torch.randint(0, 2, (s, b)).to(dtype=torch.bool)
        with pytest.raises(AssertionError):
            nystrom_attention(a, a, a, key_padding_mask=key_padding_mask)

    test_att_mask_ignored()
    test_masking()