# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest
import torch

from xformers import _is_triton_available
from xformers.ops.tiled_matmul import tiled_matmul

cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")

compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability("cuda")

cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0), reason="requires sm70+"
)

# We care about correctness, not performance, hence let's "disable" the
# expensive autotuning by removing all configs except one (the first one).
if _is_triton_available():
    from xformers.ops._triton.tiled_matmul_kernels import _xformers_tiled_matmul_kernel

    while len(_xformers_tiled_matmul_kernel.configs) > 1:
        _xformers_tiled_matmul_kernel.configs.pop()

def generate_test_shapes(*repeats, num_shapes=5):
    """For each (m, n, k) tile-count triple, draw `num_shapes` random shapes.

    Each shape is a triple of lists giving the size of every tile along the
    M, N and K dimensions; tile sizes are drawn so that each full (summed)
    dimension stays at most 1024.
    """
    shapes = []
    r = random.Random(0)
    for repeat in repeats:
        m_num_tiles, n_num_tiles, k_num_tiles = repeat
        for _ in range(num_shapes):
            shapes.append(
                (
                    [r.randint(2, 1024 // m_num_tiles) for _ in range(m_num_tiles)],
                    [r.randint(2, 1024 // n_num_tiles) for _ in range(n_num_tiles)],
                    [r.randint(2, 1024 // k_num_tiles) for _ in range(k_num_tiles)],
                )
            )
    return shapes


_test_shapes = generate_test_shapes((1, 1, 1), (3, 3, 3))
_dtypes = [torch.float32, torch.bfloat16, torch.float16]
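
# For illustration (hypothetical values, not ones produced by the generator
# above): a shape such as ([7, 5], [4, 9], [3, 8]) describes an m=12, n=13,
# k=11 matmul in which the lhs is split into a 2x2 grid of tiles and the rhs
# into a matching 2x2 grid, as consumed by test_forward_backward below.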

def ceil_of_ratio(n, k):
    return (n + k - 1) // k

def make_operands(m, n, k, *, dtype):
    """Produce lhs, rhs and reference output tensors.

    To dodge numerical accuracy differences between our kernels and PyTorch's
    ones, we avoid random values and construct matrices whose product is an
    exact integer computation, specifically: the remainder!

    We do it by having the i-th row of lhs and the j-th column of rhs be like:
    * lhs: i times "1", followed by "0"s
    * rhs: j-1 times "1", followed by "-(j-1)", then repeated

    The running sum of their pointwise product will thus be:
    1, 2, 3, ..., j-1, 0, 1, 2, 3, ... and so on
    and the final value will be the remainder of i divided by j.

    If K is smaller than M and/or N, this function also takes care of repeating
    some rows and/or columns in order to "fill" M and/or N. Similarly, if the
    precision of the dtype is too low to store the result without losses, the
    function will only use small-enough values, and repeat them as needed.

    Finally, the function permutes the rows and columns, in order to avoid a
    predictable block structure.
    """
    max_value = min(k, int(1 / torch.finfo(dtype).eps) * 2)

    m_perm = torch.randperm(m)
    n_perm = torch.randperm(n)

    num_reps_m = ceil_of_ratio(m, max_value)
    lhs = (
        torch.ones((min(m, max_value), k), dtype=dtype)
        .tril()
        .repeat([num_reps_m, 1])[m_perm, :]
    )
    assert lhs.shape == (m, k)

    num_reps_n = ceil_of_ratio(n, max_value)
    rhs = torch.ones((k, min(n, max_value)), dtype=dtype)
    for i in range(2, min(n, max_value) + 2):
        rhs[:, i - 2][i - 1 :: i] = -i + 1
    rhs = rhs.repeat([1, num_reps_n])[:, n_perm]
    assert rhs.shape == (k, n)

    lhs_idxs = torch.arange(1, min(m, max_value) + 1).repeat([num_reps_m])[m_perm, None]
    rhs_idxs = torch.arange(2, min(n, max_value) + 2).repeat([num_reps_n])[None, n_perm]
    out = torch.remainder(lhs_idxs, rhs_idxs).to(dtype)
    assert out.shape == (m, n)

    return lhs, rhs, out
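
# Illustrative sanity check (a sketch, not part of the original tests): for
# small sizes every intermediate value is exactly representable, so the dense
# product of the generated operands should equal the precomputed remainder
# table exactly, with no rounding error.
def _sanity_check_make_operands():
    lhs, rhs, out = make_operands(8, 8, 8, dtype=torch.float32)
    assert torch.equal(lhs @ rhs, out)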

@cuda_sm70_only
@pytest.mark.parametrize("shape", _test_shapes)
@pytest.mark.parametrize("dtype", _dtypes)
def test_forward_backward(
    shape,
    dtype,
):
    m_tiles, n_tiles, k_tiles = shape
    m, n, k = sum(m_tiles), sum(n_tiles), sum(k_tiles)

    torch.manual_seed(m * n * k)

    a, b, c_reference = make_operands(m, n, k, dtype=dtype)
    a = a.cuda().requires_grad_()
    b = b.cuda().requires_grad_()
    c_reference = c_reference.cuda()

    # In one operand make each tile have its own strides, in the other use the
    # same strides for all tiles. And make the two operands have stride==1 in
    # different dimensions.
    a_tiled = [
        [y.t().clone().t() for y in x.split(k_tiles, dim=1)]
        for x in a.split(m_tiles, dim=0)
    ]
    b_tiled = [[y for y in x.split(n_tiles, dim=1)] for x in b.split(k_tiles, dim=0)]

    c_test_tiled = tiled_matmul(a_tiled, b_tiled)
    c_test = torch.cat([torch.cat(x, dim=1) for x in c_test_tiled], dim=0)

    torch.testing.assert_close(c_test, c_reference)

    # To avoid numerical issues in the backward, set things up so that we only
    # multiply by a diagonal matrix whose entries are +/- 2^{-1/0/+1} (so that
    # it only changes the sign bit and the exponent).
    diag = torch.tensor(random.choices([-2, -1, -0.5, 0.5, 1, 2], k=min(m, n)))
    grad_c = torch.zeros_like(c_test)
    # Write through a diagonal view (torch.diagonal returns a view, whereas
    # torch.diag would return a copy and the assignment would be lost).
    torch.diagonal(grad_c)[:] = diag
    grad_a_reference = torch.matmul(grad_c, b.detach().t())
    grad_b_reference = torch.matmul(a.detach().t(), grad_c)

    torch.autograd.backward([c_test], [grad_c], inputs=[a, b])

    torch.testing.assert_close(a.grad, grad_a_reference)
    torch.testing.assert_close(b.grad, grad_b_reference)
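
# Minimal usage sketch (not part of the original tests): it assumes a CUDA
# (sm70+) device, picks arbitrary tile sizes, and uses loose tolerances against
# a random-valued reference, purely to illustrate the tiling convention the
# test above exercises.
def _tiled_matmul_usage_sketch():
    a = torch.randn(6, 8, device="cuda", dtype=torch.float16)
    b = torch.randn(8, 10, device="cuda", dtype=torch.float16)
    # lhs tiles: outer list over M-tiles, inner list over K-tiles.
    a_tiles = [list(row.split([3, 5], dim=1)) for row in a.split([2, 4], dim=0)]
    # rhs tiles: outer list over K-tiles, inner list over N-tiles.
    b_tiles = [list(row.split([4, 6], dim=1)) for row in b.split([3, 5], dim=0)]
    c_tiles = tiled_matmul(a_tiles, b_tiles)
    # Reassemble the output tiles and compare against a plain dense matmul.
    c = torch.cat([torch.cat(row, dim=1) for row in c_tiles], dim=0)
    torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)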