File size: 6,848 Bytes
e202b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import os
import random
from typing import Tuple

import pytest
import torch

from xformers import _is_triton_available
from xformers.ops import fused_allgather_and_linear, fused_linear_and_reducescatter

from .multiprocessing_utils import launch_subprocesses

# Compute capability of the "cuda" device; (0, 0) stands in when CUDA is not
# available, so the comparison below stays valid and the marker skips the test.
compute_capability = (
    torch.cuda.get_device_capability("cuda") if torch.cuda.is_available() else (0, 0)
)

# Skip marker for tests that need a GPU with compute capability 7.0 or newer.
cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0),
    reason="requires sm70+",
)

# Skip marker for tests that need two or more visible CUDA devices.
at_least_2_gpus = pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="needs at least 2 GPUs",
)


# These tests check correctness, not speed, so the expensive autotuning is
# effectively switched off by keeping only the first kernel config.
if _is_triton_available():
    from xformers.ops._triton.sequence_parallel_fused_kernels import (
        _xformers_seqpar_matmul_kernel,
    )

    # Drop every config after the first one, in place (configs is list-like).
    del _xformers_seqpar_matmul_kernel.configs[1:]


def compare_fused_and_non_fused_ops(
    my_rank: int,
    world_size: int,
    subgroup: torch.distributed.ProcessGroup,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
    triton: bool,
) -> None:
    """Check that a fused sequence-parallel op exactly matches a plain reference.

    Random 0/1 tensors for all ranks are generated locally, the expected result
    for this rank is computed with ordinary matmuls, and the fused op is run on
    this rank's slice. The two outputs must be bit-identical. The callers in
    this file seed the RNG with the same value on every rank before calling,
    so all ranks draw the same tensors.

    Args:
        my_rank: Index of the calling rank; selects this rank's inputs (and, for
            "reduce-scatter", this rank's weights and output shard).
        world_size: Number of ranks; used as the leading dimension of the
            generated tensors and to split the first batch dimension.
        subgroup: Process group forwarded to the fused op as ``group``.
        step: Which fused op to exercise: "all-gather" or "reduce-scatter".
        dims: ``(*batch_dims, outer_dim, inner_dim)``. The first batch
            dimension is floor-divided by ``world_size`` to get per-rank shapes.
        dtype: Floating-point dtype of all generated tensors.
        triton: Forwarded to the fused op as ``_triton``.

    Raises:
        ValueError: If ``step`` is not one of the two supported values.
        AssertionError: If the reduction sizes are too large for exact
            arithmetic in ``dtype``, or if the outputs differ.
    """
    batch_dims = dims[:-2]
    subbatch_dims = (batch_dims[0] // world_size,) + batch_dims[1:]
    outer_dim = dims[-2]
    inner_dim = dims[-1]

    # To check for correctness we want to compare the outputs but the accuracy
    # of matmuls, apparently, is not that great. We thus try to produce inputs
    # for which no rounding at all will occur. We do this by using zero or one
    # inputs, so their product will also be zero or one, and keep the reduction
    # dimension small enough so that they fit in the mantissa without overflow.
    max_exact_value = 2 * (1 / torch.finfo(dtype).eps)
    # 0.25 is the ratio of expected ones and we aim at 2/3 of the safe range
    assert outer_dim * 0.25 <= max_exact_value * 0.66
    assert inner_dim * world_size * 0.25 <= max_exact_value * 0.66

    if step == "all-gather":
        # One input shard per rank; values in [0, 1) rounded to exactly 0 or 1.
        inputs = torch.testing.make_tensor(
            (world_size,) + subbatch_dims + (outer_dim,),
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()
        weight = torch.testing.make_tensor(
            (inner_dim, outer_dim), dtype=dtype, device="cuda", low=0, high=1
        ).round()

        # Non-fused reference code: multiply every rank's shard, then merge the
        # rank dimension into the first batch dimension.
        output_reference = torch.matmul(inputs, weight.t()).flatten(0, 1)

        # Faster fused mode
        output_fused = fused_allgather_and_linear(
            inputs[my_rank], weight, group=subgroup, _triton=triton
        )

    elif step == "reduce-scatter":
        # Each rank has its own full-batch input and its own weight.
        inputs = torch.testing.make_tensor(
            (world_size,) + batch_dims + (inner_dim,),
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()
        weights = torch.testing.make_tensor(
            (world_size, outer_dim, inner_dim),
            dtype=dtype,
            device="cuda",
            low=0,
            high=1,
        ).round()

        # Non-fused reference code: for every rank, compute the slice of its
        # product that belongs to my_rank, then sum those slices over ranks.
        staging = torch.empty(
            (world_size,) + subbatch_dims + (outer_dim,), dtype=dtype, device="cuda"
        )
        for rank in range(world_size):
            torch.matmul(
                inputs[rank].tensor_split(world_size, dim=0)[my_rank],
                weights[rank].t(),
                out=staging[rank],
            )
        output_reference = torch.sum(staging, dim=0, dtype=dtype)

        # Faster fused mode
        output_fused = fused_linear_and_reducescatter(
            inputs[my_rank], weights[my_rank], group=subgroup, _triton=triton
        )

    else:
        # Without this branch an unknown step would only surface later as an
        # UnboundLocalError on output_reference, hiding the real mistake.
        raise ValueError(
            f"unknown step {step!r}; expected 'all-gather' or 'reduce-scatter'"
        )

    # Inputs were built so that no rounding can occur: demand exact equality.
    torch.testing.assert_close(output_reference, output_fused, atol=0, rtol=0)


def inner_sequence_parallel_fused(
    seed: int,
    kind: str,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
):
    """Per-rank body handed to ``launch_subprocesses`` by the fused-op test.

    Looks up this process's rank and the world size, creates a new process
    group, applies the settings implied by ``kind``, seeds the torch RNG with
    ``seed`` and then compares the fused op against the reference.

    Args:
        seed: Value passed to ``torch.random.manual_seed``.
        kind: Test flavour. "fallback" sets the
            ``DISABLE_FUSED_SEQUENCE_PARALLEL`` environment variable (presumably
            making xformers use its non-fused path -- confirm); "pytorch" turns
            the ``triton`` flag off; any other value leaves it on.
        step: "all-gather" or "reduce-scatter", forwarded unchanged.
        dims: Shape tuple, forwarded unchanged.
        dtype: Tensor dtype, forwarded unchanged.
    """
    rank = torch.distributed.get_rank()
    num_ranks = torch.distributed.get_world_size()
    group = torch.distributed.new_group()

    if kind == "fallback":
        os.environ["DISABLE_FUSED_SEQUENCE_PARALLEL"] = "1"
    # Only the "pytorch" kind asks for the non-Triton implementation.
    use_triton = kind != "pytorch"

    torch.random.manual_seed(seed)

    compare_fused_and_non_fused_ops(
        rank, num_ranks, group, step, dims, dtype, triton=use_triton
    )


# NOTE(review): only "fallback" carries the at_least_2_gpus marker, although
# "pytorch" and "triton" also launch two ranks -- confirm whether they need it.
@cuda_sm70_only
@pytest.mark.parametrize(
    "kind",
    [
        "singleton",
        pytest.param("fallback", marks=at_least_2_gpus),
        "pytorch",
        "triton",
    ],
)
@pytest.mark.parametrize("step", ["all-gather", "reduce-scatter"])
@pytest.mark.parametrize(
    "dims",
    [
        pytest.param((2, 2, 512, 512, 256), id="nice-shapes"),
        pytest.param((2, 1023, 511, 257), id="ugly-shapes"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(torch.bfloat16, id="bf16"),
        pytest.param(torch.float16, id="fp16"),
        pytest.param(torch.float32, id="fp32"),
    ],
)
def test_sequence_parallel_fused(
    kind: str,
    step: str,
    dims: Tuple[int, ...],
    dtype: torch.dtype,
):
    """Run the fused-vs-reference comparison in subprocesses for one config.

    The "singleton" kind uses a single rank; every other kind uses two. A
    fresh 32-bit seed is drawn here and handed to every rank.
    """
    num_ranks = 2
    if kind == "singleton":
        num_ranks = 1

    launch_subprocesses(
        num_ranks,
        inner_sequence_parallel_fused,
        seed=random.getrandbits(32),
        kind=kind,
        step=step,
        dims=dims,
        dtype=dtype,
    )


def inner_sequence_parallel_fused_triton_handle_all_dtypes(
    seed: int,
    step: str,
    dims: Tuple[int, ...],
):
    """Per-rank body that runs the Triton comparison for several dtypes in a row.

    Within a single process (and a single process group) the fused op is
    checked against the reference for bfloat16, float16 and float32, in that
    order, always with ``triton=True``. The RNG is seeded once, before the
    first dtype.

    Args:
        seed: Value passed to ``torch.random.manual_seed``.
        step: "all-gather" or "reduce-scatter", forwarded unchanged.
        dims: Shape tuple, forwarded unchanged.
    """
    rank = torch.distributed.get_rank()
    num_ranks = torch.distributed.get_world_size()
    group = torch.distributed.new_group()

    torch.random.manual_seed(seed)

    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        compare_fused_and_non_fused_ops(
            rank, num_ranks, group, step, dims, dtype, triton=True
        )


# NOTE(review): this test always launches two ranks but has no
# at_least_2_gpus marker -- confirm whether one is needed.
@cuda_sm70_only
@pytest.mark.parametrize("step", ["all-gather", "reduce-scatter"])
@pytest.mark.parametrize(
    "dims",
    [
        pytest.param((2, 2, 512, 512, 256), id="nice-shapes"),
        pytest.param((2, 1023, 511, 257), id="ugly-shapes"),
    ],
)
def test_sequence_parallel_fused_triton_handle_all_dtypes(
    step: str,
    dims: Tuple[int, ...],
):
    """Launch two ranks that each check the Triton path for all dtypes in turn.

    A fresh 32-bit seed is drawn here and handed to every rank.
    """
    launch_subprocesses(
        2,
        inner_sequence_parallel_fused_triton_handle_all_dtypes,
        seed=random.getrandbits(32),
        step=step,
        dims=dims,
    )