# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import sys

import pytest
import torch

import xformers

try:
    import triton
    import triton.language as tl

    from xformers.triton.vararg_kernel import unroll_varargs

    _triton_available = xformers._is_triton_available()
except ImportError as e:
    logging.warning(
        f"Triton is not available, some optimizations will not be tested.\n{e}"
    )
    _triton_available = False

enable_tests = (
    (sys.version_info.major, sys.version_info.minor) >= (3, 9)
    and _triton_available
    and torch.cuda.is_available()
)


# Skip when the environment cannot run Triton kernels
# (Python < 3.9, no Triton, or no CUDA device).
@pytest.mark.skipif(not enable_tests, reason="requires Python >= 3.9, Triton and CUDA")
def test_triton_varargs_kernel():
    @triton.jit
    def sumN(output_ptr, scaling_ptr, *inputs, BLOCK_SIZE: tl.constexpr):
        offset = tl.arange(0, BLOCK_SIZE)
        output = tl.zeros([BLOCK_SIZE], tl.float32)
        # `unroll_varargs` expands this VAR_ARGS_ARRAY annotation into N
        # concrete variables when specializing the kernel.
        scaling: "VAR_ARGS_ARRAY"  # type: ignore # noqa: F821
        for i in range(len(scaling)):
            scaling[i] = tl.load(scaling_ptr + i)

        # Accumulate each input twice, scaled by its per-input factor.
        for i in range(2):
            for j in range(len(inputs)):
                output = output + tl.load(inputs[j] + offset) * scaling[j]
        tl.store(output_ptr + offset, output)

    BLOCK_SIZE = 32
    NUM_INPUTS = 2
    torch.manual_seed(0)
    inputs = [
        torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")
        for _ in range(NUM_INPUTS)
    ]
    output = torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")
    scaling = torch.randn([NUM_INPUTS, 1], dtype=torch.float32, device="cuda")

    # Specialize the varargs kernel for exactly NUM_INPUTS inputs,
    # then launch it on a single program instance.
    sumN_unrolled = unroll_varargs(sumN, N=NUM_INPUTS)
    sumN_unrolled[(1,)](output, scaling, *inputs, BLOCK_SIZE=32)
    assert torch.allclose((2 * torch.stack(inputs) * scaling).sum(0), output)


@pytest.mark.skipif(not enable_tests, reason="requires Python >= 3.9, Triton and CUDA")
def test_triton_multiple_varargs_kernel():
    # Same as `test_triton_varargs_kernel`, but with two variable-length
    # argument groups: the input pointers and the (CPU-side) weights.
    @triton.jit
    def weighted_sumN(
        output_ptr,
        a_ptr: "VAR_ARGS_ARRAY",  # type: ignore # noqa: F821
        b: "VAR_ARGS_ARRAY",  # type: ignore # noqa: F821
        BLOCK_SIZE: tl.constexpr,
    ):
        # Weighted sum, where the weights are on CPU
        offset = tl.arange(0, BLOCK_SIZE)
        output = tl.zeros([BLOCK_SIZE], tl.float32)
        for i in range(len(a_ptr)):
            output = output + tl.load(a_ptr[i] + offset) * b[i]
        tl.store(output_ptr + offset, output)

    BLOCK_SIZE = 32
    NUM_INPUTS = 2
    torch.manual_seed(0)
    a = [
        torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")
        for _ in range(NUM_INPUTS)
    ]
    b = [torch.randn([], dtype=torch.float32, device="cuda") for _ in range(NUM_INPUTS)]
    # The weights are passed to the kernel as plain Python floats.
    b_list = [x.item() for x in b]
    output = torch.randn([BLOCK_SIZE], dtype=torch.float32, device="cuda")

    kernel = unroll_varargs(weighted_sumN, N=NUM_INPUTS)
    kernel[(1,)](output, *a, *b_list, BLOCK_SIZE=32)

    expected_output = (torch.stack(a) * torch.stack(b).unsqueeze(1)).sum(0)
    assert torch.allclose(expected_output, output)