# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import os
from typing import Optional

import torch
from torch.distributed import ProcessGroup

# Module-level model-parallel state, populated by initialize().
_GROUP: Optional[ProcessGroup] = None
_WORLD_SIZE: Optional[int] = None
_LOCAL_RANK: int = 0

def initialize(
    world_size: int,
    local_rank: int,
    group: Optional[ProcessGroup] = None,
    use_gpu: bool = True,
    seed: int = 80486,
) -> str:
    """
    Initialize model parallelism support.

    Args:
        world_size (int): the number of processes running on
            the current node available for model parallelism.
        local_rank (int): the present process' rank.
        group (torch.distributed.ProcessGroup, optional): the
            process group to use for model parallel communications.
        use_gpu (bool, optional): whether computations are
            happening on a GPU or not (defaults to True).
        seed (int, optional): the seed used to seed the PRNG
            on all model parallel processes.

    Returns:
        The PyTorch device to use in the present process.

    Note:
        If ``group`` is not specified, the default process group is
        used for model parallelism. This means that the present
        module may be incompatible with other forms of parallelism
        such as data parallelism.
    """
    global _GROUP
    global _WORLD_SIZE
    global _LOCAL_RANK

    assert local_rank < world_size
    if use_gpu:
        device = f"cuda:{local_rank}"
        torch.cuda.set_device(local_rank)
    else:
        device = "cpu"

    if group is None:
        if "MASTER_ADDR" not in os.environ:
            # No launcher provided rendezvous variables; this is only
            # valid for a single process, so fall back to localhost.
            assert world_size == 1
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "1234"
        torch.distributed.init_process_group(
            backend="nccl" if use_gpu else "gloo",
            init_method="env://",
            world_size=world_size,
            rank=local_rank,
        )

    _GROUP = group
    _WORLD_SIZE = world_size
    _LOCAL_RANK = local_rank
    torch.manual_seed(seed)
    return device
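
# Usage sketch (an illustration, not part of the module): the smallest
# possible setup, a single CPU process, which exercises the MASTER_ADDR
# fallback above. Multi-process runs would instead rely on a launcher
# (e.g. torchrun) to set MASTER_ADDR/MASTER_PORT and hand each process
# its own local_rank.
#
#   device = initialize(world_size=1, local_rank=0, use_gpu=False)
#   assert device == "cpu"
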
def get_world_size() -> int:
    if _WORLD_SIZE is None:
        raise RuntimeError("model parallelism was not initialized")
    return _WORLD_SIZE


def get_rank() -> int:
    if _WORLD_SIZE is None:
        raise RuntimeError("model parallelism was not initialized")
    return _LOCAL_RANK

def all_gather(x: torch.Tensor) -> torch.Tensor:
    """
    Gather a tensor of shape (n, m) into a tensor of shape (n, mp_size * m).
    """
    mp_size = get_world_size()
    if mp_size == 1:
        return x
    gather = [torch.empty_like(x) for _ in range(mp_size)]
    torch.distributed.all_gather(gather, x, group=_GROUP)
    return torch.cat(gather, dim=-1)
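
# Shape sketch (illustration only): with a model-parallel world size of 4 and
# a local shard x of shape (n, m), every rank receives the concatenation of
# all shards along the last dimension:
#
#   y = all_gather(x)  # y.shape == (n, 4 * m), identical on every rank
#
# This is the kind of gather a Megatron-style column-parallel linear layer
# would use to reassemble its sharded output features.
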
def all_reduce(x: torch.Tensor):
    if get_world_size() > 1:
        # reduce in place with a sum (the default op) across the group
        torch.distributed.all_reduce(x, group=_GROUP)
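

# Minimal self-check, a sketch rather than part of the original module:
# running this file directly exercises the single-process code path on CPU,
# where both collectives reduce to no-ops.
if __name__ == "__main__":
    dev = initialize(world_size=1, local_rank=0, use_gpu=False)
    t = torch.ones(2, 3, device=dev)
    assert torch.equal(all_gather(t), t)  # world_size == 1: returned as-is
    all_reduce(t)  # world_size == 1: nothing to reduce
    print(f"ok: rank {get_rank()} of {get_world_size()} on {dev}")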