# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
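
# Tests for xformers.ops.unbind / get_stack_strides / stack_or_none: the
# xformers unbind must match torch.unbind while keeping every slice backed
# by the storage of the original tensor.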
import random

import pytest
import torch

import xformers.ops
from xformers.ops.common import _get_storage_base


@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4])
def test_unbind(dim: int, contiguous: bool):
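    """xformers.ops.unbind must match torch.unbind, both in the slices it
    returns (forward) and in the gradients it propagates back to the input
    (backward), for every dim and for non-contiguous layouts."""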
    x = torch.randn([10, 20, 4, 10, 3])
    x2 = x.clone()

    if not contiguous:
        perm = list(range(x.ndim))
        # Deterministic per-dim permutation, so each parametrization is reproducible
        random.Random(dim).shuffle(perm)
        # Let's hope we didn't pick identity
        x = x.permute(perm)
        x2 = x2.permute(perm)

    assert contiguous == x.is_contiguous()
    x.requires_grad_(True)
    x2.requires_grad_(True)

    # Forward: the xformers unbind must produce exactly the same slices
    tensors = xformers.ops.unbind(x, dim)
    tensors2 = torch.unbind(x2, dim)
    assert len(tensors) == len(tensors2)
    for t1, t2 in zip(tensors, tensors2):
        assert torch.allclose(t1, t2)

    # Backward: weight each slice with a random gradient and accumulate
    # into a single loss, then compare the gradients on the inputs
    grads = torch.unbind(torch.randn(x.shape), dim)
    zero = torch.zeros_like(tensors[0])
    loss1 = sum(((g * t) for (g, t) in zip(grads, tensors)), zero)
    loss2 = sum(((g * t) for (g, t) in zip(grads, tensors2)), zero)
    assert torch.allclose(loss1, loss2)
    g = torch.randn_like(loss1)
    loss1.backward(g)
    loss2.backward(g)
    assert x.grad is not None
    assert x2.grad is not None
    assert torch.allclose(x.grad, x2.grad)


@pytest.mark.parametrize("contiguous", [True, False])
@pytest.mark.parametrize("dim", [0, 1, 2, 3, 4])
def test_unbind_get_stack_strides(dim: int, contiguous: bool):
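    """Slices produced by xformers.ops.unbind share storage with the input,
    so stack_or_none can reassemble subsets of them (all, [1:], [:2], [::2])
    as a strided view instead of materializing a copy."""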
    def not_stacked(t, d):
        # get_stack_strides returns None when the tensors cannot be viewed
        # as a single stacked tensor without a copy
        return xformers.ops.get_stack_strides(t, d) is None

    x = torch.randn([10, 20, 4, 4, 3])
    ndim = x.ndim

    # Non-contiguous tensors
    if not contiguous:
        x = x.transpose(dim, (dim + 1) % ndim)
    assert contiguous == x.is_contiguous()

    tensors = xformers.ops.unbind(x, dim)
    tensors2 = torch.unbind(x.clone(), dim)

    for cat_dim in range(ndim):
        # Stacking the unbound slices along cat_dim is equivalent to moving
        # `dim` of the original tensor to position cat_dim
        permute = list(range(ndim))
        permute.pop(dim)
        permute.insert(cat_dim, dim)
        x_permuted = x.permute(permute)

        # Tensors backed by different storages can never be re-stacked for free
        assert not_stacked([tensors2[0], tensors[1]], cat_dim), "different storage"
        assert not_stacked(
            [tensors[0], tensors[1].clone()], cat_dim
        ), "different storage"

        def test_slice(s):
            # stack_or_none must return a view with the strides of the
            # equivalent slice of the permuted input, on the same storage
            slices = [slice(None) for _ in range(ndim)]
            slices[cat_dim] = s
            reference = x_permuted[tuple(slices)]
            stacked = xformers.ops.stack_or_none(tensors[s], cat_dim)
            assert stacked is not None
            assert (
                xformers.ops.get_stack_strides(tensors[s], cat_dim)
                == reference.stride()
            )
            assert torch.allclose(stacked, torch.stack(tensors2[s], cat_dim))
            assert _get_storage_base(stacked) == _get_storage_base(tensors[0])

        # tensors
        test_slice(slice(None))
        # tensors[1:]
        test_slice(slice(1, None))
        # tensors[:2]
        test_slice(slice(None, 2))
        # tensors[::2]
        test_slice(slice(None, None, 2))
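

# Minimal usage sketch (not part of the original test suite): the property
# exercised above means slices produced by xformers.ops.unbind can be
# re-stacked for free, because they still live in the input's storage.
if __name__ == "__main__":
    x = torch.randn([3, 8, 16])
    slices = xformers.ops.unbind(x, 0)
    restacked = xformers.ops.stack_or_none(slices, 0)
    assert restacked is not None
    # Same values as torch.stack would produce, but no new storage allocated
    assert torch.equal(restacked, x)
    assert _get_storage_base(restacked) == _get_storage_base(x)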