import os

import torch

from nemo.utils import logging

try:
    from apex.transformer import parallel_state

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


def initialize_distributed(args, backend='nccl'):
    """Initialize torch.distributed."""
    # Use the local rank if it was provided explicitly (e.g., by a launcher).
    local_rank = args.local_rank

    # Global rank and world size come from the standard torch.distributed environment variables.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))

    logging.info(
        f'Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}'
    )

    # Bind this process to a CUDA device before creating the process group.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Build the TCP init method from MASTER_ADDR / MASTER_PORT and initialize the process group.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(backend=backend, world_size=world_size, rank=rank, init_method=init_method)

    return local_rank, rank, world_size
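
# Illustrative usage (a sketch, not part of the original module): ``initialize_distributed``
# only needs an ``args`` object exposing a ``local_rank`` attribute, and it reads RANK,
# WORLD_SIZE, MASTER_ADDR and MASTER_PORT from the environment (e.g., as set by ``torchrun``).
# The argparse flag name below is an assumption made for the example.
#
#   import argparse
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--local_rank', type=int, default=None)  # hypothetical flag
#   args = parser.parse_args()
#
#   local_rank, rank, world_size = initialize_distributed(args, backend='nccl')

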
def gather_objects(partial_results_list, main_rank=None):
    """
    Collect objects (e.g., results) from all GPUs.
    Useful for inference over multiple GPUs with DDP.

    Use main_rank to specify which rank will be used to gather results.
    This allows execution to continue on the main_rank only after the gather.

    Args:
        partial_results_list: list of partial results from each GPU
        main_rank: rank of the main process used to collect results from all GPUs
            (useful for collecting results on a target rank)

    Example:
        predictions = gather_objects(predictions, main_rank=0)
        # all but rank 0 will return None
        if predictions is None:
            return

        # from here on, only rank 0 should continue
        pickle.dump(predictions, open(output_fname, "wb"))
    """
    # Nothing to gather if the parallel state has not been initialized (e.g., single-process runs).
    if parallel_state.is_unitialized():
        return partial_results_list

    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()

    # With a single data-parallel rank, the partial results are already the full results.
    if world_size == 1:
        return partial_results_list

    # Collect the per-rank lists of objects on every rank.
    gathered_results = [None for _ in range(world_size)]
    torch.distributed.all_gather_object(gathered_results, partial_results_list)

    # When a main rank is requested, every other rank returns None.
    if main_rank is not None:
        if rank != main_rank:
            return None

    # Flatten the per-rank lists into a single list of results.
    results_list = []
    for r in gathered_results:
        results_list.extend(r)

    return results_list
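
# Illustrative DDP inference pattern (a sketch, not part of the original module): each
# data-parallel rank runs inference on its own shard of the data and accumulates partial
# results, then ``gather_objects`` concatenates them on rank 0. ``model``, ``shard_dataloader``
# and ``output_fname`` are placeholders assumed for this example.
#
#   predictions = []
#   for batch in shard_dataloader:               # this rank's shard of the dataset
#       predictions.extend(model.predict(batch))
#
#   predictions = gather_objects(predictions, main_rank=0)
#   if predictions is not None:                  # only main_rank receives the full list
#       pickle.dump(predictions, open(output_fname, "wb"))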