File size: 3,818 Bytes
e202b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import random

import pytest
import torch

import xformers.ops as xops
from xformers.ops import indexing

from .utils import assert_allclose


@pytest.mark.skipif(
    not indexing.ScaledIndexAddFw.is_available(), reason="not available"
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("with_scaling", [False, True])
@pytest.mark.parametrize(
    "out_shape", [(48, 1, 257 * 1536), (48, 257, 1536), (192, 50, 1536)]
)
def test_scaled_index_add(out_shape, with_scaling: bool) -> None:
    """Check ``xops.scaled_index_add`` against a float32 ``torch.index_add``.

    The reference upcasts to float32 and adds ``alpha * (scaling * src)[i]``
    (or ``alpha * src[i]`` when ``with_scaling`` is False) into row
    ``index[i]`` of ``inp`` along dim 0, then casts back to float16. Both the
    forward result and the gradients of every leaf tensor are compared.
    """
    torch.manual_seed(0)
    alpha = 0.73
    dtype = torch.float16
    n_out, seq, dim = out_shape
    # `src` only has 60% as many rows as the output tensor.
    n_src = int(n_out * 0.6)

    inp = torch.randn([n_out, seq, dim], device="cuda", dtype=dtype, requires_grad=True)
    src = torch.randn([n_src, seq, dim], device="cuda", dtype=dtype, requires_grad=True)
    leaves = {"inp": inp, "src": src}

    # Destination rows: a seeded permutation of range(n_src), so indices are
    # unique and all fall within the first n_src rows of `inp`.
    perm = list(range(n_src))
    random.Random(n_out).shuffle(perm)
    index = torch.tensor(perm, dtype=torch.int64, device="cuda")

    scaling = None
    scaled_src_ref = src.float()
    if with_scaling:
        # Per-feature scale of shape [dim], broadcast over the leading dims.
        scaling = torch.randn([dim], device="cuda", dtype=dtype, requires_grad=True)
        leaves["scaling"] = scaling
        scaled_src_ref = scaling.float() * scaled_src_ref

    ref_out = torch.index_add(
        inp.float(), dim=0, source=scaled_src_ref, index=index, alpha=alpha
    ).to(dtype)
    grad_output = torch.randn_like(ref_out)
    ref_out.backward(grad_output)
    # Stash the reference gradients, then clear them for the xFormers pass.
    ref_grads = {name: leaf.grad for name, leaf in leaves.items()}
    for leaf in leaves.values():
        leaf.grad = None

    # Forward. Pass a clone so `inp` itself is never handed to the op
    # (presumably it may write into its first argument -- not verified here).
    out = xops.scaled_index_add(inp.clone(), index, src, scaling, alpha)
    assert_allclose(out, ref_out, "fw", atol=4e-3, rtol=1e-3)

    # Backward. Tight tolerances everywhere except `scaling`: its gradient is a
    # reduction over the broadcast dims, and the op's accumulation order is not
    # identical to PyTorch's, so a small numeric difference is expected.
    out.backward(grad_output)
    for name, leaf in leaves.items():
        atol, rtol = (5e-2, 1e-2) if name == "scaling" else (1e-5, 1e-5)
        grad = leaf.grad
        assert_allclose(grad, ref_grads[name], f"{name}.grad", atol=atol, rtol=rtol)  # type: ignore


@pytest.mark.skipif(not indexing.IndexSelect.is_available(), reason="not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@pytest.mark.parametrize("D", [1536])
@pytest.mark.parametrize("batches", [((48, 25), (192, 50))])
def test_index_select_cat(D, batches) -> None:
    """Compare ``xops.index_select_cat`` with a gather-then-concatenate reference.

    All batches are slices of one contiguous ``src`` tensor. Each slice is
    reshaped to ``[B, seqlen * D]``, and 60% of its rows are picked in a seeded
    random order. The reference gathers those rows per batch, flattens them
    and concatenates everything. Forward output and ``src.grad`` are compared.
    """
    torch.manual_seed(0)
    dtype = torch.float16

    total_rows = sum(B * seqlen for B, seqlen in batches)
    src = torch.randn([total_rows, D], device="cuda", dtype=dtype, requires_grad=True)

    sources = []
    indices = []
    offset = 0
    for B, seqlen in batches:
        n_rows = B * seqlen
        # View (not copy) of this batch's rows, so gradients flow back to `src`.
        sources.append(src[offset : offset + n_rows].reshape([B, seqlen * D]))
        offset += n_rows

        order = list(range(B))
        random.Random(B).shuffle(order)
        keep = order[: int(0.6 * B)]
        indices.append(torch.tensor(keep, dtype=torch.int64, device="cuda"))

    # Reference: plain PyTorch advanced indexing followed by concatenation.
    ref_out = torch.cat([s[i].flatten() for s, i in zip(sources, indices)], dim=0)
    grad_out = torch.randn_like(ref_out)
    ref_out.backward(grad_out)
    assert src.grad is not None
    ref_grad = src.grad.clone()
    src.grad = None

    # xFormers implementation must match both outputs and input gradients.
    out = xops.index_select_cat(sources, indices)
    assert_allclose(out, ref_out, "fw")
    out.backward(grad_out)
    assert src.grad is not None
    assert_allclose(src.grad, ref_grad, "src.grad")