# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
import torch.distributed as dist

from xformers.ops import init_ipc

from .multiprocessing_utils import launch_subprocesses

compute_capability = (0, 0)
if torch.cuda.is_available():
    compute_capability = torch.cuda.get_device_capability("cuda")
cuda_sm70_only = pytest.mark.skipif(
    compute_capability < (7, 0), reason="requires sm70+"
)
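

# Scenario exercised by inner_test_ipc: every rank allocates one small CUDA
# buffer per rank, shares its buffers with its peers through the connections
# returned by init_ipc, lets each peer write its own rank into the mapped
# buffer, and finally checks that the peer's writes are visible in the
# original allocation.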
def inner_test_ipc() -> None:
    my_rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    subgroup = torch.distributed.new_group()
    ipcs = init_ipc(subgroup)
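    # init_ipc is expected to return one connection per rank in the group, with
    # None for the local rank (nothing to exchange with ourselves).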
    send_bufs = [
        torch.full([1], my_rank, device="cuda", dtype=torch.int32)
        for _ in range(world_size)
    ]
    # Shallow copy: the entry for the local rank keeps pointing at the local
    # buffer, while the entries for peers get overwritten with received tensors.
    recv_bufs = send_bufs.copy()
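    # Exchange buffers: send our buffer for each peer over its connection, then
    # receive the tensor the peer shared with us; the received tensor maps the
    # peer's memory, so writes to it should be visible on the peer's side.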
    for other_rank, conn in enumerate(ipcs):
        if conn is None:
            continue
        conn.send(send_bufs[other_rank])
    for other_rank, conn in enumerate(ipcs):
        if conn is None:
            continue
        recv_bufs[other_rank] = conn.recv()
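    # Make sure every rank has filled and mapped the shared buffers before
    # anyone reads or writes them.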
    torch.cuda.synchronize()
    dist.barrier(subgroup)
    # Use the mapped buffers to send data: write our rank into each peer's buffer.
    for other_rank, buf in enumerate(recv_bufs):
        assert buf[0].item() == other_rank
        buf.fill_(my_rank)
    torch.cuda.synchronize()
    dist.barrier(subgroup)
    # Verify we've received the data correctly: each peer should have written
    # its rank into the buffer we shared with it.
    for other_rank, buf in enumerate(send_bufs):
        assert (
            buf[0].item() == other_rank
        ), f"[#{my_rank}] {other_rank=} != {buf[0].item()=}"


@cuda_sm70_only
def test_ipc() -> None:
    world_size = 2
    launch_subprocesses(
        world_size=world_size,
        fn=inner_test_ipc,
    )


# We had an issue where the second rendezvous in a single process would use the
# same store keys as the first one, thus retrieve a stale address to connect to,
# and fail.
def inner_test_ipc_twice() -> None:
    subgroup = torch.distributed.new_group()
    init_ipc(subgroup)
    init_ipc(subgroup)


@cuda_sm70_only
def test_ipc_twice() -> None:
    world_size = 2
    launch_subprocesses(
        world_size=world_size,
        fn=inner_test_ipc_twice,
    )