# xformers/tests/test_attention_mask.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
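"""Tests for xformers.components.attention.AttentionMask.

AttentionMask wraps an additive attention mask: a kept position holds 0.0
and a masked position holds -inf, so `mask.values` can be added directly to
raw attention scores before the softmax. The tests below exercise the
constructors (from_bool, from_multiplicative, make_causal), mask addition,
and device moves.
"""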
import pytest
import torch

from xformers.components.attention import AttentionMask


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="This test requires a CUDA device"
)
def test_mask_creation():
    # Check that we can create from boolean: from_bool() builds an additive
    # mask and to_bool() round-trips back to the original boolean mask
    bool_mask = torch.rand((256, 256)) > 0.5
    additive_mask = AttentionMask.from_bool(bool_mask)
    assert (bool_mask == additive_mask.to_bool()).all()

    # A leading batch dimension is preserved
    bool_mask = torch.rand((2, 256, 256)) > 0.5
    additive_mask = AttentionMask.from_bool(bool_mask)
    assert (bool_mask == additive_mask.to_bool()).all()
    assert additive_mask.ndim == bool_mask.ndim

    # Check that we can create from a multiplicative mask
    # (1.0 means keep, 0.0 means drop)
    ref_mask = torch.randint(0, 2, (256, 256))
    mul_mask = ref_mask.float()
    additive_mask = AttentionMask.from_multiplicative(mul_mask)
    assert (ref_mask.bool() == additive_mask.to_bool()).all()

    # Check the causal mask: a lower-triangular attention pattern
    causal_mask = AttentionMask.make_causal(256, 256)
    assert (torch.tril(torch.ones(256, 256)).bool() == causal_mask.to_bool()).all()
    assert causal_mask.is_causal

    # A single argument gives a square causal mask
    causal_mask = AttentionMask.make_causal(256)
    assert (torch.tril(torch.ones(256, 256)).bool() == causal_mask.to_bool()).all()

    # Rectangular causal masks are also supported
    causal_mask = AttentionMask.make_causal(256, 128)
    assert (torch.tril(torch.ones(256, 128)).bool() == causal_mask.to_bool()).all()

    # Check that we can add masks
    bool_mask_1 = torch.rand((256, 256)) > 0.5
    add_mask_1 = AttentionMask.from_bool(bool_mask_1)

    bool_mask_2 = torch.rand((256, 256)) > 0.5
    add_mask_2 = AttentionMask.from_bool(bool_mask_2)
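
    # Summing additive masks composes them: any position that is -inf in
    # either operand stays -inf in the sum, so the result should match the
    # logical AND of the boolean masks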
    assert ((add_mask_1 + add_mask_2).to_bool() == (bool_mask_1 & bool_mask_2)).all()

    # Check type handling
    additive_mask = AttentionMask.from_bool(torch.rand((256, 256)) > 0.5)
    additive_mask = additive_mask.to(device=torch.device("cuda"))
    assert "cuda" in str(additive_mask.values.device)

    # Check that the causal flag is maintained through a device move
    additive_mask = AttentionMask.make_causal(256, 256)
    additive_mask = additive_mask.to(device=torch.device("cuda"))
    assert additive_mask.is_causal
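

# A minimal usage sketch, not part of the original test suite: assuming the
# additive convention exercised above (0.0 = keep, -inf = drop), the mask's
# `values` tensor can be added to raw attention scores before the softmax.
# The helper name below is hypothetical and only illustrative; the leading
# underscore keeps pytest from collecting it as a test.
def _attention_with_mask_sketch():
    q = torch.rand((256, 64))
    k = torch.rand((256, 64))
    scores = q @ k.transpose(-2, -1) / (64**0.5)

    causal_mask = AttentionMask.make_causal(256, 256)
    weights = torch.softmax(scores + causal_mask.values, dim=-1)

    # Masked (upper-triangular) positions receive exactly zero weight,
    # since exp(-inf) == 0
    assert (weights.masked_select(~causal_mask.to_bool()) == 0).all()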