Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from xformers.components.attention.utils import ( | |
maybe_merge_masks, | |
reshape_key_padding_mask, | |
) | |
def test_reshape_key_padding_mask():
    """Exercise ``reshape_key_padding_mask`` and ``maybe_merge_masks``.

    Checks that:

    * a random boolean mask of shape ``(batch_size, seq_len)`` is reshaped
      to ``(batched_dim, 1, seq_len)``;
    * ``maybe_merge_masks`` called with ``att_mask=None`` returns a tensor
      equal to that reshaped mask expanded to ``seq_len`` along dim 1;
    * a random boolean mask of shape ``(batched_dim, seq_len)`` is also
      reshaped to ``(batched_dim, 1, seq_len)``.
    """
    batch_size, num_heads, seq_len = 2, 2, 4
    batched_dim = batch_size * num_heads
    expected_shape = (batched_dim, 1, seq_len)

    # Case 1: the mask's leading dimension is batch_size.
    small_mask = torch.randint(0, 2, (batch_size, seq_len)).to(dtype=torch.bool)
    reshaped = reshape_key_padding_mask(
        key_padding_mask=small_mask, batched_dim=batched_dim
    )
    assert reshaped.size() == expected_shape

    # With no attention mask supplied, the merged result must match the
    # reshaped key padding mask expanded along dim 1.
    merged = maybe_merge_masks(
        att_mask=None,
        key_padding_mask=small_mask,
        batch_size=batch_size,
        src_len=seq_len,
        num_heads=num_heads,
    )
    expected_merged = reshaped.expand(-1, seq_len, -1)
    assert torch.equal(merged, expected_merged)

    # Case 2: the mask's leading dimension is already batched_dim.
    large_mask = torch.randint(0, 2, (batched_dim, seq_len)).to(dtype=torch.bool)
    reshaped = reshape_key_padding_mask(
        key_padding_mask=large_mask, batched_dim=batched_dim
    )
    assert reshaped.size() == expected_shape