# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from typing import Optional

import torch
from torch.distributed import ProcessGroup

# Module-level state set by ``initialize`` and read by the helpers below.
_GROUP: Optional[ProcessGroup] = None
_WORLD_SIZE: Optional[int] = None
_LOCAL_RANK: int = 0


def initialize(
    world_size: int,
    local_rank: int,
    group: Optional[ProcessGroup] = None,
    use_gpu: bool = True,
    seed: int = 80486,
) -> str:
"""
Initialize model parallelism support.
Args:
world_size (int): the number of processes running on
the current node available for model parallelism.
local_rank (int): the present process' rank.
group (torch.distributed.ProcessGroup, optional): the
process group to use for model parallel communications.
use_gpu (bool, optional): whether computations are
happening on a GPU or not (defaults to True).
seed (int, optional): the seed used to seed the prng
on all model parallel processes
Returns
The pytorch device to use in the present process.
Note:
If ``group`` is not specified, the default process group is
used for model parallelism. This means that the present
module may be incompatible with other forms of parallelism
such as data parallelism.
"""
    global _GROUP
    global _WORLD_SIZE
    global _LOCAL_RANK

    assert local_rank < world_size
    if use_gpu:
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(local_rank)
    else:
        device = "cpu"
    if group is None:
        if "MASTER_ADDR" not in os.environ:
            # No launcher set up a rendezvous; fall back to a single-process
            # group on localhost.
            assert world_size == 1
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "1234"
        torch.distributed.init_process_group(
            backend="nccl" if use_gpu else "gloo",
            init_method="env://",
            world_size=world_size,
            rank=local_rank,
        )
    _GROUP = group
    _WORLD_SIZE = world_size
    _LOCAL_RANK = local_rank
    torch.manual_seed(seed)
    return device
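
# Illustrative usage sketch, assuming one process per GPU launched by a tool
# such as ``torchrun`` (which sets WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and
# MASTER_PORT in the environment):
#
#     device = initialize(
#         world_size=int(os.environ["WORLD_SIZE"]),
#         local_rank=int(os.environ["LOCAL_RANK"]),
#     )
#     x = torch.randn(4, 8, device=device)
#     gathered = all_gather(x)  # shape (4, 8 * world_size) on every rank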


@functools.cache
def get_world_size() -> int:
    if _WORLD_SIZE is None:
        raise RuntimeError("model parallelism was not initialized")
    return _WORLD_SIZE


@functools.cache
def get_rank() -> int:
    if _WORLD_SIZE is None:
        raise RuntimeError("model parallelism was not initialized")
    return _LOCAL_RANK


def all_gather(x: torch.Tensor) -> torch.Tensor:
    """
    Gather a tensor of shape (n, m) into a tensor of shape (n, mp_size * m).
    """
    mp_size = get_world_size()
    if mp_size == 1:
        return x
    gather = [torch.empty_like(x) for _ in range(mp_size)]
    torch.distributed.all_gather(gather, x, group=_GROUP)
    return torch.cat(gather, dim=-1)
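
# Shape example (illustrative): with mp_size == 2, a local (4, 8) tensor on
# each rank is gathered into a (4, 16) tensor on every rank, the local shards
# being concatenated along the last dimension.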


def all_reduce(x: torch.Tensor) -> None:
    if get_world_size() > 1:
        # reduce with a sum
        torch.distributed.all_reduce(x, group=_GROUP)
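

# Minimal single-process smoke test (illustrative sketch, assuming a CPU-only
# run; with a world size of 1 the collectives are effectively no-ops):
if __name__ == "__main__":
    dev = initialize(world_size=1, local_rank=0, use_gpu=False)
    t = torch.ones(2, 3, device=dev)
    gathered = all_gather(t)  # unchanged, since mp_size == 1
    all_reduce(gathered)      # in-place sum over a single process: also a no-op
    print(dev, tuple(gathered.shape))
    torch.distributed.destroy_process_group()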