# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import concurrent.futures
import gc
import multiprocessing
import os
import signal
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Dict, List, Tuple
import torch


class SafeMpContext(multiprocessing.context.BaseContext):
    """A "spawn" multiprocessing context that tracks every process it creates,
    so that they can all be terminated, and their exit codes reported, at
    teardown time.
    """

def __init__(self) -> None:
self.mp_context = multiprocessing.get_context("spawn")
        self.processes: List[multiprocessing.context.SpawnProcess] = []

    def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnProcess:
p = self.mp_context.Process(*args, **kwargs)
p.daemon = True
self.processes.append(p)
        return p

    def kill_all_processes(self):
for p in self.processes:
p.terminate()
p.join(1)
            # (https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Process.exitcode)
            # Even though the Python documentation suggests that the exit code
            # should be set once join() returns, this is not what we have
            # observed in practice. We therefore loop until it becomes set.
while p.exitcode is None:
p.kill()
p.join()
            assert p.exitcode is not None, f"{p} is still alive"

    def log_bad_exit_codes(self):
for rank, p in enumerate(self.processes):
if p.exitcode == 0:
continue
if p.exitcode < 0:
try:
signal_desc = f" (signal {signal.Signals(-p.exitcode).name})"
except ValueError:
signal_desc = " (unrecognized signal)"
else:
signal_desc = ""
print(
f"Child process for rank #{rank} with PID {p.pid} exited with code {p.exitcode}{signal_desc}"
            )

    def __getattr__(self, name: str):
        return getattr(self.mp_context, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
self.kill_all_processes()
self.log_bad_exit_codes()
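

# A minimal usage sketch of SafeMpContext on its own (illustrative only, not
# exercised by this module): processes created through the context are tracked,
# and the context forcibly cleans them up on exit, even if a child hangs.
#
#     with SafeMpContext() as ctx:
#         p = ctx.Process(target=print, args=("hello from child",))
#         p.start()
#         p.join()

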
def init_process_group(init_method: str, rank: int, world_size: int):
torch._C._set_print_stack_traces_on_fatal_signal(True)
if torch.cuda.device_count() >= world_size:
backend = "nccl"
torch.cuda.set_device(rank)
else:
# Use Gloo instead of NCCL so that we can run on a single GPU
backend = "gloo"
torch.distributed.init_process_group(
backend=backend,
world_size=world_size,
rank=rank,
init_method=init_method,
)
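

# For reference, a sketch (not called anywhere in this module) of how the
# file:// rendezvous is used: every rank of the job must call
# init_process_group with the same rendezvous path and world_size before
# issuing any collective, e.g.
#
#     init_process_group("file:///tmp/my_rdv_file", rank=0, world_size=2)
#     t = torch.ones(1)
#     torch.distributed.all_reduce(t)  # t now holds world_size

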
def _launch_subprocesses_fn_wrapper(
init_method: str,
rank: int,
world_size: int,
parent_env_vars: Dict[str, str],
user_fn,
args,
kwargs,
):
    # Initialize the environment of the spawned subprocess by capturing and applying the current
    # environment variables from the parent process. By clearing and then updating `os.environ`
    # with `parent_env_vars`, we ensure that each spawned subprocess starts with an environment
    # that mirrors the parent process at the time of job submission. This keeps subprocesses
    # consistent with the parent's latest environment, especially when they are reused for
    # subsequent job executions.
os.environ.clear()
os.environ.update(parent_env_vars)
# Check if the process group is already initialized
if not torch.distributed.is_initialized():
init_process_group(init_method, rank, world_size)
try:
return user_fn(*args, **kwargs)
finally:
        # This should free all memory used by PyTorch in the subprocess
        gc.collect()
        torch.cuda.empty_cache()


# Global dictionary keeping track of executors and their rendezvous temporary files, keyed by world size
EXECUTORS_AND_FILES: Dict[
int, Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]
] = {}


def get_global_pool_allocator(
world_size: int,
) -> Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]:
global EXECUTORS_AND_FILES
if world_size not in EXECUTORS_AND_FILES:
rdv = NamedTemporaryFile(mode="w+b", buffering=-1, delete=False)
mp_context = SafeMpContext()
executor = concurrent.futures.ProcessPoolExecutor(
max_workers=world_size, mp_context=mp_context
)
        # Add the executor and temporary file to the global dictionary
        EXECUTORS_AND_FILES[world_size] = (rdv, executor)
else:
rdv, executor = EXECUTORS_AND_FILES[world_size]
return rdv, executor
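

# Note (illustrative, not exercised by this module): calling
# get_global_pool_allocator twice with the same world_size returns the same
# executor, so the worker processes (and their initialized process groups) are
# reused across calls:
#
#     _, ex_a = get_global_pool_allocator(2)
#     _, ex_b = get_global_pool_allocator(2)
#     assert ex_a is ex_b

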
class ProcessPoolExecutorManager:
    """Hands out the (possibly reused) process pool for a given world size and
    tears it down if one of the submitted jobs fails.
    """

    def __init__(self, world_size: int):
        self.world_size = world_size

    def __enter__(self):
        # Before starting subprocesses, free the memory used by PyTorch in the
        # main process, so that the subprocesses have memory available.
gc.collect()
torch.cuda.empty_cache()
self.rdv, self.executor = get_global_pool_allocator(self.world_size)
        return self

    def submit(self, fn, *args, **kwargs):
        return self.executor.submit(fn, *args, **kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # One of the subprocess jobs has failed
if exc_val:
            # We want to avoid killing the processes while the executor still
            # believes they are up and healthy, as this may have unintended
            # consequences (such as the executor restarting the processes, or
            # reporting spurious errors).
            # Set the internal state of the executor and call cancel() on each
            # issued task that is not yet executing.
self.executor.shutdown(wait=False, cancel_futures=True)
# Kill all remaining subprocesses
mp_context = self.executor._mp_context
mp_context.kill_all_processes()
mp_context.log_bad_exit_codes()
            # We want to wait for all the futures to complete, so we need to shut down a second time, now waiting
self.executor.shutdown(wait=True)
# Close the temporary file
self.rdv.close()
            # Remove the executor from the global dictionary. It will be
            # recreated the next time a test requires this world_size.
assert self.world_size in EXECUTORS_AND_FILES
del EXECUTORS_AND_FILES[self.world_size]
            print(
                f"Shut down and removed the executor after a subprocess error. Executor count: {len(EXECUTORS_AND_FILES)}"
            )


def launch_subprocesses(world_size: int, fn, *args, **kwargs):
    # This custom manager lets each test execution enter/exit the following context.
    # On entry, it creates a new ProcessPoolExecutor for the given world size, or reuses an existing one.
    # On exit, it detects whether an exception was raised, in which case it kills all spawned processes,
    # deletes the manager, and lets the next request recreate the manager and respawn the processes.
with ProcessPoolExecutorManager(world_size) as manager:
futures = [
manager.submit(
_launch_subprocesses_fn_wrapper,
init_method=f"file://{manager.rdv.name}",
rank=rank,
world_size=world_size,
parent_env_vars=dict(os.environ),
user_fn=fn,
args=args,
kwargs=kwargs,
)
for rank in range(world_size)
]
done, _ = concurrent.futures.wait(
futures, return_when=concurrent.futures.FIRST_EXCEPTION
)
for f in done:
f.result()
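

# A minimal, self-contained usage sketch. The `_demo_all_reduce` helper below
# is purely illustrative and not part of the original module: each rank
# contributes its rank id to an all_reduce, so every rank should end up with
# sum(range(world_size)).
def _demo_all_reduce() -> int:
    rank = torch.distributed.get_rank()
    t = torch.tensor([rank])
    if torch.distributed.get_backend() == "nccl":
        # NCCL only operates on CUDA tensors; init_process_group already
        # selected the right device for this rank.
        t = t.cuda()
    torch.distributed.all_reduce(t)
    return int(t.item())


if __name__ == "__main__":
    launch_subprocesses(2, _demo_all_reduce)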