File size: 5,510 Bytes
e202b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest
import torch

from xformers import _is_triton_available
from xformers.ops.tiled_matmul import tiled_matmul

# Skip markers for GPU-only tests. Without CUDA the compute capability is
# reported as (0, 0), which also makes the sm70 marker skip.
compute_capability = (
    torch.cuda.get_device_capability("cuda") if torch.cuda.is_available() else (0, 0)
)
cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0), reason="requires sm70+"
)

# These tests check correctness, not speed, so the costly autotuning is
# effectively switched off: keep only the kernel's first config and drop the
# rest (in place, so the kernel object sees the trimmed list).
if _is_triton_available():
    from xformers.ops._triton.tiled_matmul_kernels import _xformers_tiled_matmul_kernel

    del _xformers_tiled_matmul_kernel.configs[1:]


def generate_test_shapes(*repeats, num_shapes=5):
    """Return deterministic random tile layouts for tiled-matmul tests.

    Each positional argument is an ``(m_num_tiles, n_num_tiles, k_num_tiles)``
    triple. For every triple, ``num_shapes`` layouts are produced, in order.
    A layout is a tuple of three lists holding the per-tile sizes along M, N
    and K. Each size is drawn from ``[2, 1024 // num_tiles]``, so no dimension
    totals more than 1024. A fixed seed (0) makes the output reproducible.
    """
    rng = random.Random(0)

    def _tile_sizes(num_tiles):
        # Draw one random size per tile for a single dimension.
        upper = 1024 // num_tiles
        return [rng.randint(2, upper) for _ in range(num_tiles)]

    # Tuple elements are evaluated left to right, so draws happen M, then N,
    # then K, for each layout.
    return [
        (_tile_sizes(m_tiles), _tile_sizes(n_tiles), _tile_sizes(k_tiles))
        for m_tiles, n_tiles, k_tiles in repeats
        for _ in range(num_shapes)
    ]


# Dtypes the forward/backward test is parametrized over.
_dtypes = [
    torch.float32,
    torch.bfloat16,
    torch.float16,
]
# Random tile layouts: a single tile per dimension, then three tiles per
# dimension (generate_test_shapes' default num_shapes applies to each).
_test_shapes = generate_test_shapes((1, 1, 1), (3, 3, 3))


def ceil_of_ratio(n, k):
    """Return ``ceil(n / k)`` using only integer floor division.

    For integer ``n`` and positive integer ``k``, padding the numerator by
    ``k - 1`` makes floor division round up instead of down.
    """
    padded = n + k - 1
    return padded // k


def make_operands(m, n, k, *, dtype):
    """Build ``lhs`` (m x k), ``rhs`` (k x n) and their exact product ``out``.

    Comparing against a product of random data would make the test sensitive
    to rounding differences between our kernels and PyTorch's. Instead the
    operands are crafted so that every entry of the product is a small
    integer, namely a remainder, which the dtype can store exactly.

    Count the rows of ``lhs`` from 1, and associate each column of ``rhs``
    with a modulus ``j >= 2``:
    * row ``i`` of ``lhs`` is ``i`` ones followed by zeros;
    * the column of ``rhs`` is ``j - 1`` ones followed by ``-(j - 1)``, with
      that pattern repeated down the column.
    The running sum of their pointwise product is 1, 2, ..., j - 1, 0, 1, 2,
    ..., so the dot product ends up equal to ``i mod j``.

    At most ``max_value = min(k, 2 * int(1 / eps(dtype)))`` distinct rows and
    columns are built. This keeps every row's run of ones within ``k`` and
    keeps all values small enough for ``dtype``. When M and/or N are larger,
    those rows and/or columns are repeated to fill M and/or N.

    Finally, rows and columns are shuffled with ``torch.randperm`` (consuming
    the global torch RNG) so the output has no predictable block structure.

    Returns a ``(lhs, rhs, out)`` tuple of ``dtype`` tensors with shapes
    ``(m, k)``, ``(k, n)`` and ``(m, n)``.
    """
    max_value = min(k, int(1 / torch.finfo(dtype).eps) * 2)
    # Draw the row permutation before the column one (RNG order matters).
    m_perm = torch.randperm(m)
    n_perm = torch.randperm(n)

    distinct_rows = min(m, max_value)
    distinct_cols = min(n, max_value)
    m_reps = ceil_of_ratio(m, max_value)
    n_reps = ceil_of_ratio(n, max_value)

    # Lower-triangular ones: row r holds r + 1 leading ones.
    staircase = torch.ones((distinct_rows, k), dtype=dtype).tril()
    lhs = staircase.repeat([m_reps, 1])[m_perm, :]
    assert lhs.shape == (m, k)

    rhs = torch.ones((k, distinct_cols), dtype=dtype)
    for col in range(distinct_cols):
        period = col + 2
        # Every period-th entry cancels the preceding period - 1 ones.
        rhs[:, col][period - 1 :: period] = 1 - period
    rhs = rhs.repeat([1, n_reps])[:, n_perm]
    assert rhs.shape == (k, n)

    dividends = torch.arange(1, distinct_rows + 1).repeat([m_reps])[m_perm, None]
    divisors = torch.arange(2, distinct_cols + 2).repeat([n_reps])[None, n_perm]
    out = torch.remainder(dividends, divisors).to(dtype)
    assert out.shape == (m, n)

    return lhs, rhs, out


@cuda_only
@cuda_sm70_only
@pytest.mark.parametrize("shape", _test_shapes, ids=[str(x) for x in _test_shapes])
@pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
def test_forward_backward(
    shape,
    dtype,
):
    """Check tiled_matmul's forward output and input gradients on CUDA.

    ``shape`` is a triple of per-tile size lists along M, N and K. The forward
    result is compared against the exact product produced by
    ``make_operands``. The backward is driven by a diagonal upstream gradient
    whose entries are +/- powers of two, and ``a.grad`` / ``b.grad`` are
    compared against ``torch.matmul`` references.
    """
    m_tiles, n_tiles, k_tiles = shape
    m, n, k = sum(m_tiles), sum(n_tiles), sum(k_tiles)

    seed = m * n * k
    torch.manual_seed(seed)

    a, b, c_reference = make_operands(m, n, k, dtype=dtype)
    a = a.cuda().requires_grad_()
    b = b.cuda().requires_grad_()
    c_reference = c_reference.cuda()

    # In one operand make each tile have its own strides, in the other use the
    # same stride for all tiles. And make the two operands have the stride==1
    # in different dimensions.
    a_tiled = [
        [y.t().clone().t() for y in x.split(k_tiles, dim=1)]
        for x in a.split(m_tiles, dim=0)
    ]
    b_tiled = [list(x.split(n_tiles, dim=1)) for x in b.split(k_tiles, dim=0)]

    c_test_tiled = tiled_matmul(a_tiled, b_tiled)
    c_test = torch.cat([torch.cat(x, dim=1) for x in c_test_tiled], dim=0)

    torch.testing.assert_close(c_test, c_reference)

    # To avoid numerical issues in the backward, set things up so that we only
    # multiply by a diagonal matrix whose entries are +/- 2^{-1/0/+1} (so that
    # it only changes the sign bit and the exponent). A dedicated RNG seeded
    # like torch keeps the chosen gradient reproducible across runs.
    diag = torch.tensor(
        random.Random(seed).choices([-2, -1, -0.5, 0.5, 1, 2], k=min(m, n)),
        dtype=c_test.dtype,
        device=c_test.device,
    )
    grad_c = torch.zeros_like(c_test)
    # torch.diag() on a 2-D tensor returns a *copy* of the diagonal, so
    # assigning into it would leave grad_c all zeros and make the gradient
    # checks below vacuous. Tensor.diagonal() is a view, so copy_ fills grad_c.
    grad_c.diagonal().copy_(diag)
    # All chosen values are non-zero, so the diagonal must be fully populated.
    assert int(torch.count_nonzero(grad_c)) == min(m, n)
    grad_a_reference = torch.matmul(grad_c, b.detach().t())
    grad_b_reference = torch.matmul(a.detach().t(), grad_c)

    torch.autograd.backward([c_test], [grad_c], inputs=[a, b])

    torch.testing.assert_close(a.grad, grad_a_reference)
    torch.testing.assert_close(b.grad, grad_b_reference)