Spaces:
Sleeping
Sleeping
Dit-document-layout-analysis
/
unilm
/decoding
/IAD
/fairseq
/tests
/distributed
/test_distributed_utils.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools | |
import sys | |
import unittest | |
import torch | |
from fairseq import distributed_utils as dist_utils | |
from .utils import objects_are_equal, spawn_and_init | |
class TestDistributedUtils(unittest.TestCase):
    """Tests for ``dist_utils.broadcast_object``.

    Each test binds a reference object to ``_test_broadcast_object`` with
    ``functools.partial`` and passes the resulting callable to
    ``spawn_and_init`` with ``world_size=2``.
    NOTE(review): ``spawn_and_init`` presumably starts one worker process per
    rank and calls the callable with ``(rank, group)`` -- confirm against
    ``.utils``.
    """

    def setUp(self):
        """Skip the test when the required hardware or platform is missing."""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")

    def test_broadcast_object_python(self):
        """Broadcast a plain Python string."""
        spawn_and_init(
            functools.partial(
                TestDistributedUtils._test_broadcast_object,
                "hello world",
            ),
            world_size=2,
        )

    def test_broadcast_object_tensor(self):
        """Broadcast a single tensor."""
        spawn_and_init(
            functools.partial(
                TestDistributedUtils._test_broadcast_object,
                torch.rand(5),
            ),
            world_size=2,
        )

    def test_broadcast_object_complex(self):
        """Broadcast a dict mixing containers, scalars and tensors."""
        spawn_and_init(
            functools.partial(
                TestDistributedUtils._test_broadcast_object,
                {
                    "a": "1",
                    "b": [2, torch.rand(2, 3), 3],
                    "c": (torch.rand(2, 3), 4),
                    "d": {5, torch.rand(5)},
                    "e": torch.rand(5),
                    "f": torch.rand(5).int().cuda(),
                },
            ),
            world_size=2,
        )

    # Declared static because the helper takes no ``self``: without the
    # decorator, reaching it through an instance would bind the instance to
    # ``ref_obj``. Access through the class (as the tests do) is unaffected.
    @staticmethod
    def _test_broadcast_object(ref_obj, rank, group):
        """Broadcast ``ref_obj`` from rank 0 and check the received copy.

        Args:
            ref_obj: reference object; only rank 0 passes it to the broadcast,
                every other rank passes ``None``.
            rank: rank of the calling worker.
            group: process group forwarded to ``broadcast_object``.

        Raises:
            AssertionError: if the object returned by the broadcast is not
                equal to ``ref_obj`` according to ``objects_are_equal``.
        """
        obj = dist_utils.broadcast_object(
            ref_obj if rank == 0 else None, src_rank=0, group=group
        )
        assert objects_are_equal(ref_obj, obj)
if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's runner.
    unittest.main()