# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import ctypes

import torch
import torch.distributed as dist


def init_dist(local_rank, backend='nccl', **kwargs):
    r"""Initialize distributed training"""
    if dist.is_available():
        if dist.is_initialized():
            return torch.cuda.current_device()
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, init_method='env://', **kwargs)

    # Increase the L2 fetch granularity for faster speed.
    _libcudart = ctypes.CDLL('libcudart.so')
    # Set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    # assert pValue.contents.value == 128
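

# Hypothetical usage sketch (not part of the original module): how a
# per-process training entry point might call init_dist. Reading LOCAL_RANK
# from the environment is an assumption about the launcher (e.g. torchrun),
# not something this module requires.
def _example_init_dist():
    import os
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    init_dist(local_rank)
    master_only_print(f'Initialized {get_world_size()} process(es).')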


def get_rank():
    r"""Get rank of the thread."""
    rank = 0
    if dist.is_available():
        if dist.is_initialized():
            rank = dist.get_rank()
    return rank


def get_world_size():
    r"""Get world size. How many GPUs are available in this job."""
    world_size = 1
    if dist.is_available():
        if dist.is_initialized():
            world_size = dist.get_world_size()
    return world_size


def master_only(func):
    r"""Apply this function only to the master GPU."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        r"""Simple function wrapper for the master function"""
        if get_rank() == 0:
            return func(*args, **kwargs)
        else:
            return None
    return wrapper
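

# Hypothetical usage sketch (not part of the original module): master_only is
# typically applied to side-effecting helpers such as checkpoint saving, so
# that only rank 0 touches the filesystem; all other ranks return None.
@master_only
def _example_save_checkpoint(model, path):
    torch.save(model.state_dict(), path)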


def is_master():
    r"""Check whether the current process is the master."""
    return get_rank() == 0


def is_local_master():
    r"""Check whether the current process is the local master (device 0) on its node."""
    return torch.cuda.current_device() == 0


@master_only
def master_only_print(*args):
    r"""master-only print"""
    print(*args)


def dist_reduce_tensor(tensor, rank=0, reduce='mean'):
    r"""Reduce the tensor to the given rank (rank 0 by default)."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.reduce(tensor, dst=rank)
        if get_rank() == rank:
            if reduce == 'mean':
                tensor /= world_size
            elif reduce == 'sum':
                pass
            else:
                raise NotImplementedError
    return tensor
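

# Hypothetical usage sketch (not part of the original module): average a
# per-GPU scalar loss onto rank 0, e.g. so that only the master process logs
# it during training.
def _example_log_mean_loss(loss):
    reduced = dist_reduce_tensor(loss.detach(), rank=0, reduce='mean')
    if is_master():
        # Only the destination rank is guaranteed to hold the averaged value.
        print(f'mean loss: {reduced.item():.4f}')
    return reduced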


def dist_all_reduce_tensor(tensor, reduce='mean'):
    r"""All-reduce the tensor so that every rank holds the reduced result."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.all_reduce(tensor)
        if reduce == 'mean':
            tensor /= world_size
        elif reduce == 'sum':
            pass
        else:
            raise NotImplementedError
    return tensor
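

# Hypothetical usage sketch (not part of the original module): sum per-GPU
# counts so that every rank computes the same global accuracy, e.g. during
# distributed validation.
def _example_global_accuracy(num_correct, num_total):
    # Both arguments are assumed to be scalar CUDA tensors local to this GPU;
    # clone() keeps the in-place all_reduce from modifying the caller's data.
    correct = dist_all_reduce_tensor(num_correct.clone(), reduce='sum')
    total = dist_all_reduce_tensor(num_total.clone(), reduce='sum')
    return correct / total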


def dist_all_gather_tensor(tensor):
    r"""Gather the tensor from every rank to all ranks."""
    world_size = get_world_size()
    if world_size < 2:
        return [tensor]
    tensor_list = [
        torch.ones_like(tensor) for _ in range(world_size)]
    with torch.no_grad():
        dist.all_gather(tensor_list, tensor)
    return tensor_list
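

# Hypothetical usage sketch (not part of the original module): gather the
# per-GPU embedding batches on every rank and concatenate them, a common step
# before computing a loss over the full global batch. Assumes every rank
# contributes a tensor of identical shape.
def _example_gather_embeddings(embeddings):
    gathered = dist_all_gather_tensor(embeddings)
    return torch.cat(gathered, dim=0)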