NeMo/nemo/utils/distributed.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from nemo.utils import logging
try:
    from apex.transformer import parallel_state

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


def initialize_distributed(args, backend='nccl'):
"""Initialize torch.distributed."""
    # Get local rank in case it is provided.
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))

    logging.info(
        f'Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}'
    )

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(backend=backend, world_size=world_size, rank=rank, init_method=init_method)

    return local_rank, rank, world_size
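
# Usage sketch (illustrative, not part of NeMo's API): initialize_distributed only reads
# `args.local_rank`, so any argparse-style namespace works. A hypothetical launch script
# started with `torchrun` (which sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT/LOCAL_RANK)
# might call it like this:
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")))
#   args = parser.parse_args()
#   local_rank, rank, world_size = initialize_distributed(args, backend="nccl")
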
def gather_objects(partial_results_list, main_rank=None):
    """
    Collect picklable objects (e.g., inference results) from all GPUs.

    Useful for inference over multiple GPUs with DDP.
    Use main_rank to specify which rank gathers the results; all other ranks
    then receive None, so execution can continue on main_rank only after the gather.
    When DDP is not initialized or world_size is 1, the input list is returned unchanged.

    Args:
        partial_results_list: list of partial results from each GPU
        main_rank: rank of the main process that collects the results from all GPUs
            (useful for collecting results on a single target rank)

    Example:
        predictions = gather_objects(predictions, main_rank=0)
        # all ranks but rank 0 will return None
        if predictions is None:
            return

        # from here on, only rank 0 should continue
        pickle.dump(predictions, open(output_fname, "wb"))
    """
    # do not fail when DDP is not initialized
    # (note: `is_unitialized` is the spelling used by apex.transformer.parallel_state)
    if parallel_state.is_unitialized():
        return partial_results_list
    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()

    # return input when no DDP is used
    if world_size == 1:
        return partial_results_list

    gathered_results = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(gathered_results, partial_results_list)

    # return None to non-main ranks
    if main_rank is not None:
        if rank != main_rank:
            return None

    # return collected results
    results_list = []
    for r in gathered_results:
        results_list.extend(r)

    return results_list
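

# Minimal end-to-end sketch (an illustration added here, not part of the original module):
# each rank produces a dummy partial result and rank 0 gathers the merged list. Assumes a
# CUDA node, Apex installed, and a launch via `torchrun --nproc_per_node=<N> distributed.py`.
# If apex.transformer.parallel_state has not been initialized, gather_objects simply
# returns the local partial list on every rank.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Stand-in for an argparse namespace; only `local_rank` is read by initialize_distributed.
    demo_args = SimpleNamespace(local_rank=int(os.getenv("LOCAL_RANK", "0")))
    local_rank, rank, world_size = initialize_distributed(demo_args, backend="nccl")

    # One dummy "prediction" per rank; with data parallelism initialized, rank 0 receives all of them.
    partial = [f"prediction-from-rank-{rank}"]
    merged = gather_objects(partial, main_rank=0)
    if merged is not None:
        logging.info(f"rank {rank} collected {len(merged)} results: {merged}")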