import os
import io
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Initialize the default NCCL process group and bind this process to its GPU.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
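
# Usage sketch (illustrative, not part of the original file): one way to launch a
# single-node run, spawning one process per GPU. The worker function, address, and
# port below are assumptions for the example.
#
#   import torch.multiprocessing as mp
#
#   def _worker(local_rank, world_size):
#       setup_dist(rank=local_rank, local_rank=local_rank, world_size=world_size,
#                  master_addr='127.0.0.1', master_port='29500')
#       ...  # build model, wrap in DDP, train
#
#   if __name__ == '__main__':
#       n_gpus = torch.cuda.device_count()
#       mp.spawn(_worker, args=(n_gpus,), nprocs=n_gpus)
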
def read_file_dist(path):
    """
    Read a binary file in a distributed setting.

    The file is read from disk only by the rank-0 process and then broadcast
    to all other processes.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    if dist.is_initialized() and dist.get_world_size() > 1:
        # read the file on rank 0 and record its size
        size = torch.LongTensor(1).cuda()
        if dist.get_rank() == 0:
            with open(path, 'rb') as f:
                data = f.read()
            data = torch.ByteTensor(
                torch.UntypedStorage.from_buffer(data, dtype=torch.uint8)
            ).cuda()
            size[0] = data.shape[0]
        # broadcast the size so the other ranks can allocate a receive buffer
        dist.broadcast(size, src=0)
        if dist.get_rank() != 0:
            data = torch.ByteTensor(size[0].item()).cuda()
        # broadcast the file contents
        dist.broadcast(data, src=0)
        # convert back to io.BytesIO
        data = data.cpu().numpy().tobytes()
        data = io.BytesIO(data)
        return data
    else:
        with open(path, 'rb') as f:
            data = f.read()
        data = io.BytesIO(data)
        return data
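
# Usage sketch (illustrative, not part of the original file): loading a checkpoint
# that every rank needs while touching the filesystem only on rank 0. The path and
# model are placeholders; torch.load accepts the returned io.BytesIO directly.
#
#   state_dict = torch.load(read_file_dist('checkpoint.pt'), map_location='cpu')
#   model.load_state_dict(state_dict)
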
def unwrap_dist(model):
    """
    Unwrap the model from its DistributedDataParallel wrapper, if any.
    """
    if isinstance(model, DDP):
        return model.module
    return model
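
# Usage sketch (illustrative, not part of the original file): saving weights
# without the 'module.' prefix that DDP adds to parameter names, so the
# checkpoint can later be loaded into a plain (non-DDP) model.
#
#   torch.save(unwrap_dist(model).state_dict(), 'checkpoint.pt')
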
@contextmanager
def master_first():
    """
    A context manager that ensures the master process (global rank 0) executes
    its block first; the other ranks wait at a barrier until it has finished.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
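
# Usage sketch (illustrative, not part of the original file): letting rank 0
# download or build a cached resource before the remaining ranks read it.
# `prepare_cache` is a hypothetical helper that writes to a path visible to
# all ranks.
#
#   with master_first():
#       dataset = prepare_cache('/shared/data')
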
@contextmanager
def local_master_first():
    """
    A context manager that ensures the local master process (the first rank on
    each node) executes its block first; the remaining ranks wait at a barrier.

    Assumes ranks are assigned contiguously per node, with
    torch.cuda.device_count() ranks on every node.
    """
    if not dist.is_initialized():
        yield
    else:
        if dist.get_rank() % torch.cuda.device_count() == 0:
            yield
            dist.barrier()
        else:
            dist.barrier()
            yield
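
# Usage sketch (illustrative, not part of the original file): one preparation
# step per node when each node keeps its own local cache; the local master runs
# first and the other ranks on that node wait at the barrier before reading the
# warm cache. `prepare_cache` is again a hypothetical helper.
#
#   with local_master_first():
#       dataset = prepare_cache('/tmp/node_local_cache')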