# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
from typing import Tuple

import pytest
import torch

from xformers import _is_triton_available
from xformers.ops import (
    sequence_parallel_leading_matmul,
    sequence_parallel_trailing_matmul,
)

from .multiprocessing_utils import launch_subprocesses

compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability("cuda")
cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0), reason="requires sm70+"
)
at_least_2_gpus = pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="needs at least 2 GPUs"
)

# We care about correctness, not performance, hence let's "disable" the
# expensive autotuning by removing all configs except one (the first one).
if _is_triton_available():
    from xformers.ops._triton.sequence_parallel_fused_kernels import (
        _xformers_seqpar_matmul_kernel,
    )

    while len(_xformers_seqpar_matmul_kernel.configs) > 1:
        _xformers_seqpar_matmul_kernel.configs.pop()

    from xformers.ops._triton.tiled_matmul_kernels import _xformers_tiled_matmul_kernel

    while len(_xformers_tiled_matmul_kernel.configs) > 1:
        _xformers_tiled_matmul_kernel.configs.pop()


def reference_leading(input_, w1, w2):
    hidden1 = torch.matmul(input_, w1.t())
    hidden2 = torch.matmul(input_, w2.t())
    return [hidden1, hidden2]


def reference_trailing(hidden, w):
    output = torch.matmul(hidden, w.t())
    return output


def xformers_leading(input_, w1, w2, *, fuse, group):
    return sequence_parallel_leading_matmul(
        input_, [w1.t(), w2.t()], fuse=fuse, process_group=group
    )


def xformers_trailing(hidden, w, *, fuse, group):
    return sequence_parallel_trailing_matmul(
        hidden, w.t(), fuse=fuse, process_group=group
    )


def inner_seqpar(
    kind: str,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
    seed: int,
):
    my_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    subgroup = torch.distributed.new_group()

    fused = True
    if kind == "unfused":
        fused = False
    elif kind == "fallback":
        os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1"

    torch.random.manual_seed(seed)

    batch_dims = dims[:-2]
    outer_dim = dims[-2]
    inner_dim = dims[-1]

    # To check for correctness we want to compare the outputs, but the accuracy
    # of matmuls, apparently, is not that great. We thus try to produce inputs
    # for which no rounding at all will occur. We do this by using zero-or-one
    # inputs, so their products will also be zero or one, and by keeping the
    # reduction dimension small enough that the sums fit in the mantissa
    # without overflow.
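    # Worked example of the bound computed below: a float can represent every
    # integer up to 2 / eps exactly. For fp16, eps = 2^-10, so integers up to
    # 2 * 2^10 = 2048 are exact; for bf16 (eps = 2^-7) the limit is 256; for
    # fp32 (eps = 2^-23) it is 16_777_216. With bf16 and the "ugly" shapes,
    # inner_dim * world_size * 0.25 = 257 * 2 * 0.25 = 128.5 <= 256 * 0.66
    # ~= 169, so the asserts below pass for every dtype tested here.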
    max_exact_value = 2 * (1 / torch.finfo(dtype).eps)
    # 0.25 is the expected fraction of ones among the products (each 0/1
    # factor is one with probability 0.5), and we aim at 2/3 of the safe range.
    assert outer_dim * 0.25 <= max_exact_value * 0.66
    assert inner_dim * world_size * 0.25 <= max_exact_value * 0.66

    def my_chunk(t, *, dim):
        # This rank's shard of t along the given dimension.
        return t.tensor_split(world_size, dim=dim)[my_rank]

    if step == "leading":
        input_ = torch.testing.make_tensor(
            batch_dims + (outer_dim,),
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()
        weight1, weight2 = [
            torch.testing.make_tensor(
                (inner_dim * (idx + 1), outer_dim),
                dtype=dtype,
                device="cuda",
                low=0,
                high=1,
            ).round()
            for idx in range(2)
        ]
        gradient1, gradient2 = [
            torch.testing.make_tensor(
                batch_dims + (inner_dim * (idx + 1),),
                dtype=dtype,
                device="cuda",
                low=0,
                high=1,
            ).round()
            for idx in range(2)
        ]

        # Non-fused reference code
        input_ref = input_.detach().requires_grad_()
        weight1_ref = weight1.detach().requires_grad_()
        weight2_ref = weight2.detach().requires_grad_()

        output1_ref, output2_ref = reference_leading(
            input_ref, weight1_ref, weight2_ref
        )
        torch.autograd.backward([output1_ref, output2_ref], [gradient1, gradient2])

        my_output1_ref = my_chunk(output1_ref, dim=-1)
        my_output2_ref = my_chunk(output2_ref, dim=-1)
        my_weight1_grad_ref = my_chunk(weight1_ref.grad, dim=0)
        my_weight2_grad_ref = my_chunk(weight2_ref.grad, dim=0)
        my_input_grad_ref = my_chunk(input_ref.grad, dim=0)

        # Faster fused mode
        my_input_xf = my_chunk(input_, dim=0).detach().requires_grad_()
        my_weight1_xf = my_chunk(weight1, dim=0).detach().requires_grad_()
        my_weight2_xf = my_chunk(weight2, dim=0).detach().requires_grad_()
        my_gradient1 = my_chunk(gradient1, dim=-1)
        my_gradient2 = my_chunk(gradient2, dim=-1)

        my_output1_xf, my_output2_xf = xformers_leading(
            my_input_xf, my_weight1_xf, my_weight2_xf, fuse=fused, group=subgroup
        )
        torch.autograd.backward(
            [my_output1_xf, my_output2_xf], [my_gradient1, my_gradient2]
        )

        my_weight1_grad_xf = my_weight1_xf.grad
        my_weight2_grad_xf = my_weight2_xf.grad
        my_input_grad_xf = my_input_xf.grad

        # Checks
        torch.testing.assert_close(my_output1_ref, my_output1_xf)
        torch.testing.assert_close(my_output2_ref, my_output2_xf)
        torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf)
        torch.testing.assert_close(my_weight1_grad_ref, my_weight1_grad_xf)
        torch.testing.assert_close(my_weight2_grad_ref, my_weight2_grad_xf)

    elif step == "trailing":
        input_ = torch.testing.make_tensor(
            batch_dims + (inner_dim,),
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()
        weight = torch.testing.make_tensor(
            (outer_dim, inner_dim),
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()
        gradient = torch.testing.make_tensor(
            batch_dims + (outer_dim,),
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()

        # Non-fused reference code
        input_ref = input_.detach().requires_grad_()
        weight_ref = weight.detach().requires_grad_()

        output_ref = reference_trailing(input_ref, weight_ref)
        torch.autograd.backward([output_ref], [gradient])

        my_output_ref = my_chunk(output_ref, dim=0)
        my_weight_grad_ref = my_chunk(weight_ref.grad, dim=1)
        my_input_grad_ref = my_chunk(input_ref.grad, dim=-1)

        # Faster fused mode
        my_input_xf = my_chunk(input_, dim=-1).detach().clone().requires_grad_()
        my_weight_xf = my_chunk(weight, dim=1).detach().requires_grad_()
        my_gradient = my_chunk(gradient, dim=0)

        my_output_xf = xformers_trailing(
            my_input_xf, my_weight_xf, fuse=fused, group=subgroup
        )
        torch.autograd.backward([my_output_xf], [my_gradient])

        my_weight_grad_xf = my_weight_xf.grad
        my_input_grad_xf = my_input_xf.grad

        # Checks
        torch.testing.assert_close(my_output_ref, my_output_xf)
        torch.testing.assert_close(my_input_grad_ref, my_input_grad_xf)
        torch.testing.assert_close(my_weight_grad_ref, my_weight_grad_xf)
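

# Sharding scheme exercised above (as inferred from the my_chunk calls): in
# the "leading" step the input is sharded along the batch/sequence dim and
# each weight along its output-features dim, so each output comes back
# sharded along its last (feature) dim -- i.e. sequence-parallel in,
# tensor-parallel out. The "trailing" step is the mirror image: feature-
# sharded input and weight, row-sharded (sequence-parallel) output.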


@cuda_sm70_only
@pytest.mark.parametrize(
    "kind",
    [
        "singleton",
        pytest.param("unfused", marks=at_least_2_gpus),
        pytest.param("fallback", marks=at_least_2_gpus),
        "fused",
    ],
)
@pytest.mark.parametrize(
    "step",
    [
        "leading",
        "trailing",
    ],
)
@pytest.mark.parametrize(
    "dims",
    [
        pytest.param((2, 2, 512, 512, 256), id="nice-shapes"),
        pytest.param((2, 1023, 511, 257), id="ugly-shapes"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(torch.bfloat16, id="bf16"),
        pytest.param(torch.float16, id="fp16"),
        pytest.param(torch.float32, id="fp32"),
    ],
)
def test_seqpar(
    kind: str,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
):
    world_size = 1 if kind == "singleton" else 2
    seed = random.getrandbits(32)
    launch_subprocesses(
        world_size=world_size,
        fn=inner_seqpar,
        kind=kind,
        step=step,
        dims=dims,
        dtype=dtype,
        seed=seed,
    )
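

# A minimal, self-contained sketch of the zero/one exactness trick used above
# (not part of the test suite; it runs only when this file is executed
# directly on a CUDA machine). With 0/1 entries and a reduction dimension
# below max_exact_value, a low-precision matmul matches an fp64 reference
# bit-for-bit.
if __name__ == "__main__" and torch.cuda.is_available():
    a = torch.randint(0, 2, (128, 256), device="cuda").half()
    b = torch.randint(0, 2, (256, 64), device="cuda").half()
    # Sums of 0/1 products over 256 terms stay well below fp16's 2048 limit,
    # so casting the fp64 result down to fp16 loses nothing.
    exact = (a.double() @ b.double()).half()
    assert torch.equal(a @ b, exact), "fp16 matmul on 0/1 inputs should be exact"
    print("fp16 matmul on 0/1 inputs is exact, as expected")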