# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from xformers.components.attention.utils import (
    maybe_merge_masks,
    reshape_key_padding_mask,
)

def test_reshape_key_padding_mask():
    batch_size = 2
    num_heads = 2
    seq_len = 4
    batched_dim = batch_size * num_heads

    # Per-batch padding mask: one boolean row of length seq_len per sequence.
    key_padding_mask = torch.randint(0, 2, (batch_size, seq_len)).to(dtype=torch.bool)
    reshaped_mask = reshape_key_padding_mask(
        key_padding_mask=key_padding_mask, batched_dim=batched_dim
    )
    # The mask is repeated across heads and gains a broadcastable middle dim.
    assert reshaped_mask.size() == (batched_dim, 1, seq_len)

    # With no attention mask, merging should simply broadcast the padding
    # mask over the query dimension.
    merged_mask = maybe_merge_masks(
        att_mask=None,
        key_padding_mask=key_padding_mask,
        batch_size=batch_size,
        src_len=seq_len,
        num_heads=num_heads,
    )
    assert torch.equal(merged_mask, reshaped_mask.expand(-1, seq_len, -1))

    # A mask that is already batched per head keeps its leading dimension.
    key_padding_mask = torch.randint(0, 2, (batched_dim, seq_len)).to(dtype=torch.bool)
    reshaped_mask = reshape_key_padding_mask(
        key_padding_mask=key_padding_mask, batched_dim=batched_dim
    )
    assert reshaped_mask.size() == (batched_dim, 1, seq_len)
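

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original test: how the two helpers
# compose when an explicit attention mask is also present. It assumes that
# `maybe_merge_masks` accepts a boolean att_mask of shape
# (batch_size * num_heads, seq_len, seq_len) and combines it with the
# reshaped key padding mask by broadcasting; treat this as a usage sketch
# under those assumptions, not a verified reference.
if __name__ == "__main__":
    batch_size, num_heads, seq_len = 2, 2, 4
    batched_dim = batch_size * num_heads

    key_padding_mask = torch.randint(0, 2, (batch_size, seq_len)).to(dtype=torch.bool)

    # Hypothetical causal (lower-triangular) mask, replicated per batched head.
    att_mask = (
        torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        .unsqueeze(0)
        .expand(batched_dim, -1, -1)
    )

    merged = maybe_merge_masks(
        att_mask=att_mask,
        key_padding_mask=key_padding_mask,
        batch_size=batch_size,
        src_len=seq_len,
        num_heads=num_heads,
    )
    print(merged.shape)  # expected: (batched_dim, seq_len, seq_len)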