# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import random

import pytest
import torch

from xformers import _is_triton_available
from xformers.ops.tiled_matmul import tiled_matmul

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")

compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability("cuda")
cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0), reason="requires sm70+"
)

# We care about correctness, not performance, hence let's "disable" the
# expensive autotuning by removing all configs except one (the first one).
if _is_triton_available():
    from xformers.ops._triton.tiled_matmul_kernels import _xformers_tiled_matmul_kernel

    while len(_xformers_tiled_matmul_kernel.configs) > 1:
        _xformers_tiled_matmul_kernel.configs.pop()


def generate_test_shapes(*repeats, num_shapes=5):
    shapes = []
    r = random.Random(0)
    for repeat in repeats:
        m_num_tiles, n_num_tiles, k_num_tiles = repeat
        for _ in range(num_shapes):
            shapes.append(
                (
                    [r.randint(2, 1024 // m_num_tiles) for _ in range(m_num_tiles)],
                    [r.randint(2, 1024 // n_num_tiles) for _ in range(n_num_tiles)],
                    [r.randint(2, 1024 // k_num_tiles) for _ in range(k_num_tiles)],
                )
            )
    return shapes


_test_shapes = generate_test_shapes((1, 1, 1), (3, 3, 3))
_dtypes = [torch.float32, torch.bfloat16, torch.float16]


def ceil_of_ratio(n, k):
    # Ceiling division: the smallest integer that is >= n / k.
    return (n + k - 1) // k


def make_operands(m, n, k, *, dtype):
    """Produce lhs, rhs and reference output tensors

    To dodge numerical accuracy differences between our kernels and PyTorch's
    ones, we avoid random values and construct matrices whose product is an
    exact mathematical computation, specifically: the remainder!

    We do it by having the i-th row of lhs and the j-th column of rhs be like:
    * lhs: i times "1", followed by "0"
    * rhs: j-1 times "1", followed by "-(j-1)", then repeated

    The running sum of their pointwise product will thus be:
    1, 2, 3, ..., j-1, 0, 1, 2, 3, ... and so on

    And the final value will be the remainder of i divided by j.

    If K is smaller than M and/or N, this function also takes care of repeating
    some rows and/or columns in order to "fill" M and/or N. Similarly, if the
    precision of the dtype is too low to store the result without losses, the
    function will only use small-enough values, and repeat them as needed.

    Finally, the function permutes the rows and columns, in order to avoid a
    predictable block structure.
    """
    max_value = min(k, int(1 / torch.finfo(dtype).eps) * 2)
    m_perm = torch.randperm(m)
    n_perm = torch.randperm(n)

    num_reps_m = ceil_of_ratio(m, max_value)
    lhs = (
        torch.ones((min(m, max_value), k), dtype=dtype)
        .tril()
        .repeat([num_reps_m, 1])[m_perm, :]
    )
    assert lhs.shape == (m, k)

    num_reps_n = ceil_of_ratio(n, max_value)
    rhs = torch.ones((k, min(n, max_value)), dtype=dtype)
    for i in range(2, min(n, max_value) + 2):
        rhs[:, i - 2][i - 1 :: i] = -i + 1
    rhs = rhs.repeat([1, num_reps_n])[:, n_perm]
    assert rhs.shape == (k, n)

    lhs_idxs = torch.arange(1, min(m, max_value) + 1).repeat([num_reps_m])[m_perm, None]
    rhs_idxs = torch.arange(2, min(n, max_value) + 2).repeat([num_reps_n])[None, n_perm]
    out = torch.remainder(lhs_idxs, rhs_idxs).to(dtype)
    assert out.shape == (m, n)

    return lhs, rhs, out


@cuda_only
@cuda_sm70_only
@pytest.mark.parametrize("shape", _test_shapes, ids=[str(x) for x in _test_shapes])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
def test_forward_backward(
    shape,
    dtype,
):
    m_tiles, n_tiles, k_tiles = shape
    m, n, k = sum(m_tiles), sum(n_tiles), sum(k_tiles)

    torch.manual_seed(m * n * k)

    a, b, c_reference = make_operands(m, n, k, dtype=dtype)
    a = a.cuda().requires_grad_()
    b = b.cuda().requires_grad_()
    c_reference = c_reference.cuda()

    # In one operand make each tile have its own strides, in the other use the
    # same stride for all tiles. And make the two operands have the stride==1
    # in different dimensions.
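    # Concretely (descriptive note): `.t().clone().t()` below gives every tile
    # of `a` its own storage in a column-major layout (stride 1 along dim 0),
    # while the tiles of `b` stay plain views into the row-major `b`, so they
    # all share the same strides (stride 1 along dim 1).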
    a_tiled = [
        [y.t().clone().t() for y in x.split(k_tiles, dim=1)]
        for x in a.split(m_tiles, dim=0)
    ]
    b_tiled = [[y for y in x.split(n_tiles, dim=1)] for x in b.split(k_tiles, dim=0)]

    c_test_tiled = tiled_matmul(a_tiled, b_tiled)
    c_test = torch.cat([torch.cat(x, dim=1) for x in c_test_tiled], dim=0)

    torch.testing.assert_close(c_test, c_reference)

    # To avoid numerical issues in the backward, set things up so that we only
    # multiply by a diagonal matrix whose entries are +/- 2^{-1/0/+1} (so that
    # it only changes the sign bit and the exponent).
    diag = torch.tensor(random.choices([-2, -1, -0.5, 0.5, 1, 2], k=min(m, n)))
    grad_c = torch.zeros_like(c_test)
    # torch.diagonal returns a view, so this fills the main diagonal of grad_c
    # in place (torch.diag would return a copy and leave grad_c all zeros).
    torch.diagonal(grad_c)[:] = diag
    grad_a_reference = torch.matmul(grad_c, b.detach().t())
    grad_b_reference = torch.matmul(a.detach().t(), grad_c)

    torch.autograd.backward([c_test], [grad_c], inputs=[a, b])

    torch.testing.assert_close(a.grad, grad_a_reference)
    torch.testing.assert_close(b.grad, grad_b_reference)