# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import concurrent.futures
import gc
import multiprocessing
import os
import signal
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Dict, List, Tuple

import torch


# A "spawn" multiprocessing context that keeps track of every process it creates, so
# that they can all be force-killed and have their exit codes logged on teardown.
class SafeMpContext(multiprocessing.context.BaseContext):
    def __init__(self) -> None:
        self.mp_context = multiprocessing.get_context("spawn")
        self.processes: List[multiprocessing.context.SpawnProcess] = []

    def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnProcess:
        p = self.mp_context.Process(*args, **kwargs)
        # Daemonize so that children cannot outlive the parent process
        p.daemon = True
        self.processes.append(p)
        return p

    def kill_all_processes(self):
        for p in self.processes:
            p.terminate()
            p.join(1)
            # (https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Process.exitcode)
            # Even though the Python documentation seems to say that the exit code should
            # be set after joining, this is not what we have observed in practice. We
            # therefore loop until it becomes set.
            while p.exitcode is None:
                p.kill()
                p.join()
            assert p.exitcode is not None, f"{p} is still alive"

    def log_bad_exit_codes(self):
        for rank, p in enumerate(self.processes):
            if p.exitcode == 0:
                continue
            if p.exitcode < 0:
                try:
                    signal_desc = f" (signal {signal.Signals(-p.exitcode).name})"
                except ValueError:
                    signal_desc = " (unrecognized signal)"
            else:
                signal_desc = ""
            print(
                f"Child process for rank #{rank} with PID {p.pid} exited with code {p.exitcode}{signal_desc}"
            )

    def __getattr__(self, name: str):
        # Delegate everything else (Queue, Pipe, ...) to the underlying "spawn" context
        return getattr(self.mp_context, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.kill_all_processes()
        self.log_bad_exit_codes()
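# Illustrative sketch (not exercised by the code below): SafeMpContext can also be used
# as a standalone context manager, guaranteeing that any process spawned through it is
# terminated and has its exit code logged when the block exits, even on error:
#
#     with SafeMpContext() as ctx:
#         p = ctx.Process(target=print, args=("hello from child",))
#         p.start()
#         p.join()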
def init_process_group(init_method: str, rank: int, world_size: int):
    # Ask PyTorch to print Python stack traces when a fatal signal is received
    torch._C._set_print_stack_traces_on_fatal_signal(True)

    if torch.cuda.device_count() >= world_size:
        backend = "nccl"
        torch.cuda.set_device(rank)
    else:
        # Use Gloo instead of NCCL so that we can run on a single GPU (or on CPU only)
        backend = "gloo"
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method,
    )


def _launch_subprocesses_fn_wrapper(
    init_method: str,
    rank: int,
    world_size: int,
    parent_env_vars: Dict[str, str],
    user_fn,
    args,
    kwargs,
):
    # This function initializes the environment of the spawned subprocess by capturing
    # and applying the current environment variables from the parent process. By clearing
    # and then updating `os.environ` with `parent_env_vars`, we ensure that each spawned
    # subprocess starts with an environment that mirrors the parent process at the time
    # of job submission. This guarantees consistency across subprocesses, reflecting the
    # latest state of the parent's environment variables even (and especially) when the
    # subprocesses are reused for subsequent job executions.
    os.environ.clear()
    os.environ.update(parent_env_vars)

    # Initialize the process group only once: the pool reuses workers across jobs
    if not torch.distributed.is_initialized():
        init_process_group(init_method, rank, world_size)
    try:
        return user_fn(*args, **kwargs)
    finally:
        # Free all memory used by PyTorch in the subprocess, so that the next job
        # starts from a clean slate
        gc.collect()
        torch.cuda.empty_cache()


# Global dictionary keeping track of the executors and their rendezvous files,
# keyed by world size
EXECUTORS_AND_FILES: Dict[
    int, Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]
] = {}


def get_global_pool_allocator(
    world_size: int,
) -> Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]:
    global EXECUTORS_AND_FILES

    if world_size not in EXECUTORS_AND_FILES:
        # Rendezvous file used by torch.distributed's file:// init method
        rdv = NamedTemporaryFile(mode="w+b", buffering=-1, delete=False)

        mp_context = SafeMpContext()
        executor = concurrent.futures.ProcessPoolExecutor(
            max_workers=world_size, mp_context=mp_context
        )

        # Add the executor and its temporary file to the global dictionary
        EXECUTORS_AND_FILES[world_size] = (rdv, executor)
    else:
        rdv, executor = EXECUTORS_AND_FILES[world_size]

    return rdv, executor


class ProcessPoolExecutorManager:
    def __init__(self, world_size: int):
        self.world_size = world_size

    def __enter__(self):
        # Before starting the subprocesses, free the memory used by PyTorch in the main
        # process, so that the subprocesses have memory available
        gc.collect()
        torch.cuda.empty_cache()

        self.rdv, self.executor = get_global_pool_allocator(self.world_size)
        return self

    def submit(self, fn, *args, **kwargs):
        return self.executor.submit(fn, *args, **kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # One of the subprocess jobs has failed
        if exc_val:
            # We want to avoid killing the processes while the executor still believes
            # they are up and healthy (as this may have unintended consequences, such as
            # the executor restarting the processes, or reporting spurious errors).
            # Set the internal state of the executor and call cancel() on each issued
            # task that is not executing.
            self.executor.shutdown(wait=False, cancel_futures=True)

            # Kill all remaining subprocesses
            mp_context = self.executor._mp_context
            mp_context.kill_all_processes()
            mp_context.log_bad_exit_codes()

            # We want to wait for all the futures to complete, so we need to shut down twice
            self.executor.shutdown(wait=True)

            # Close the temporary rendezvous file
            self.rdv.close()

            # Remove the executor from the global dictionary: it will be recreated the
            # next time a test requires this world_size
            assert self.world_size in EXECUTORS_AND_FILES
            del EXECUTORS_AND_FILES[self.world_size]

            print(
                f"Shut down and removed the executor after a subprocess error. Executors count: {len(EXECUTORS_AND_FILES)}"
            )


def launch_subprocesses(world_size: int, fn, *args, **kwargs):
    # Each test execution enters/exits the custom manager's context below. On entry, it
    # creates a new ProcessPoolExecutor with the given world size, or reuses an existing
    # one. On exit, it detects whether an exception occurred, in which case it kills all
    # spawned processes and deletes the manager, so that the next request recreates the
    # manager and respawns the processes.
    with ProcessPoolExecutorManager(world_size) as manager:
        futures = [
            manager.submit(
                _launch_subprocesses_fn_wrapper,
                init_method=f"file://{manager.rdv.name}",
                rank=rank,
                world_size=world_size,
                parent_env_vars=dict(os.environ),
                user_fn=fn,
                args=args,
                kwargs=kwargs,
            )
            for rank in range(world_size)
        ]
        done, _ = concurrent.futures.wait(
            futures, return_when=concurrent.futures.FIRST_EXCEPTION
        )
        for f in done:
            f.result()
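
# Usage sketch: a minimal, self-contained demo of `launch_subprocesses`, assuming this
# module is run directly. `_example_all_reduce` is a hypothetical demo function, not
# part of this module's API; any picklable top-level function works the same way.
def _example_all_reduce() -> None:
    rank = torch.distributed.get_rank()
    # NCCL only reduces CUDA tensors; Gloo also handles CPU tensors
    device = "cuda" if torch.distributed.get_backend() == "nccl" else "cpu"
    t = torch.tensor([float(rank)], device=device)
    # Every rank contributes its rank id; afterwards all ranks hold the sum 0 + 1 = 1
    torch.distributed.all_reduce(t)
    print(f"rank {rank}: sum of ranks = {t.item()}")


if __name__ == "__main__":
    # The demo function must be defined at module level (not under this guard) so that
    # the spawned workers can resolve it when unpickling the submitted task.
    launch_subprocesses(2, _example_all_reduce)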