Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import random | |
import pytest | |
import torch | |
import xformers.ops as xops | |
from xformers.ops import indexing | |
from .utils import assert_allclose | |
def test_scaled_index_add(out_shape, with_scaling: bool) -> None:
    """Compare ``xops.scaled_index_add`` with a ``torch.index_add`` reference.

    The reference is computed on float32 copies of the float16 inputs and
    cast back to float16. Both the forward output and the gradients of every
    leaf tensor (``inp``, ``src`` and, when enabled, ``scaling``) are checked.

    Args:
        out_shape: ``(B_out, M, D)`` shape of the tensor that is added into.
            The source tensor uses ``int(B_out * 0.6)`` rows instead of
            ``B_out``.
        with_scaling: When true, a random per-``D`` scaling vector multiplies
            the source; otherwise ``None`` is passed as the scaling argument.
    """
    torch.manual_seed(0)
    alpha = 0.73
    dtype = torch.float16
    B_out, M, D = out_shape
    B_src = int(B_out * 0.6)

    # The random draws below happen in a fixed order (inp, src, scaling,
    # grad_output) so that the seeded values stay reproducible.
    inp = torch.randn([B_out, M, D], device="cuda", dtype=dtype, requires_grad=True)
    src = torch.randn([B_src, M, D], device="cuda", dtype=dtype, requires_grad=True)
    leaves = {"inp": inp, "src": src}

    # Destination rows: a shuffled permutation of range(B_src), seeded by B_out.
    permutation = list(range(src.shape[0]))
    random.Random(B_out).shuffle(permutation)
    index = torch.tensor(permutation, dtype=torch.int64, device="cuda")

    if not with_scaling:
        scaling = None
        ref_src_scaled = src.float()
    else:
        scaling = torch.randn([D], device="cuda", dtype=dtype, requires_grad=True)
        leaves["scaling"] = scaling
        ref_src_scaled = scaling.float() * src.float()

    # Reference forward pass in float32, then cast back to the test dtype.
    ref_out_fp32 = torch.index_add(
        inp.float(), dim=0, index=index, source=ref_src_scaled, alpha=alpha
    )
    ref_out = ref_out_fp32.to(dtype)

    # Reference backward pass; stash the gradients and reset the leaves so the
    # second backward pass starts from a clean state.
    grad_output = torch.randn_like(ref_out)
    ref_out.backward(grad_output)
    ref_grads = {}
    for key, leaf in leaves.items():
        ref_grads[key] = leaf.grad
        leaf.grad = None

    # Test FW
    out = xops.scaled_index_add(inp.clone(), index, src, scaling, alpha)
    assert_allclose(out, ref_out, "fw", atol=4e-3, rtol=1e-3)

    # Test BW
    out.backward(grad_output)
    for key, leaf in leaves.items():
        # NOTE: The order of operations differs slightly from PyTorch's, so the
        # "scaling" gradient is compared with a looser tolerance.
        atol, rtol = (5e-2, 1e-2) if key == "scaling" else (1e-5, 1e-5)
        assert_allclose(leaf.grad, ref_grads[key], f"{key}.grad", atol=atol, rtol=rtol)  # type: ignore
def test_index_select_cat(D, batches) -> None:
    """Compare ``xops.index_select_cat`` with an index-then-``torch.cat`` reference.

    One float16 ``src`` tensor of shape ``[sum(B * seqlen), D]`` is carved into
    consecutive slices, one per ``(B, seqlen)`` entry of ``batches``, each
    reshaped to ``[B, seqlen * D]``. For every slice, the first
    ``int(0.6 * B)`` entries of a shuffled ``range(B)`` (seeded by ``B``) are
    used as row indices. Forward output and the gradient w.r.t. ``src`` are
    checked against the PyTorch reference.

    Args:
        D: Size of the last dimension of ``src``.
        batches: Iterable of ``(B, seqlen)`` pairs; it is iterated twice.
    """
    torch.manual_seed(0)
    dtype = torch.float16

    total_rows = sum(B * seqlen for B, seqlen in batches)
    src = torch.randn([total_rows, D], device="cuda", dtype=dtype, requires_grad=True)

    indices, sources = [], []
    offset = 0
    for B, seqlen in batches:
        chunk_rows = B * seqlen
        order = list(range(B))
        random.Random(B).shuffle(order)
        kept = order[: int(0.6 * B)]
        indices.append(torch.tensor(kept, dtype=torch.int64, device="cuda"))
        chunk = src[offset : offset + chunk_rows]
        sources.append(chunk.reshape([B, seqlen * D]))
        offset += chunk_rows

    # PT implem: select rows per source, flatten, and concatenate.
    ref_pieces = []
    for source, selected in zip(sources, indices):
        ref_pieces.append(source[selected].flatten())
    ref_out = torch.cat(ref_pieces, dim=0)

    gradient_out = torch.randn_like(ref_out)
    ref_out.backward(gradient_out)
    assert src.grad is not None
    ref_grad = src.grad.clone()
    # Reset so the xFormers backward pass does not accumulate on top.
    src.grad = None

    # xFormers implem
    out = xops.index_select_cat(sources, indices)
    assert_allclose(out, ref_out, "fw")

    out.backward(gradient_out)
    assert src.grad is not None
    assert_allclose(src.grad, ref_grad, "src.grad")