# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import concurrent.futures
import gc
import multiprocessing
import os
import signal
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import Dict, List, Tuple

import torch


class SafeMpContext(multiprocessing.context.BaseContext):
    def __init__(self) -> None:
        self.mp_context = multiprocessing.get_context("spawn")
        self.processes: List[multiprocessing.context.SpawnProcess] = []

    def Process(self, *args, **kwargs) -> multiprocessing.context.SpawnProcess:
        p = self.mp_context.Process(*args, **kwargs)
        p.daemon = True
        self.processes.append(p)
        return p

    def kill_all_processes(self):
        for p in self.processes:
            p.terminate()
            p.join(1)

            # (https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Process.exitcode)
            # Even though the Python documentation suggests that the exitcode should be set after
            # joining, this is not what we have observed in practice, so we loop until it is set.
            while p.exitcode is None:
                p.kill()
                p.join()

            assert p.exitcode is not None, f"{p} is still alive"

    def log_bad_exit_codes(self):
        for rank, p in enumerate(self.processes):
            if p.exitcode == 0:
                continue
            if p.exitcode < 0:
                try:
                    signal_desc = f" (signal {signal.Signals(-p.exitcode).name})"
                except ValueError:
                    signal_desc = " (unrecognized signal)"
            else:
                signal_desc = ""
            print(
                f"Child process for rank #{rank} with PID {p.pid} exited with code {p.exitcode}{signal_desc}"
            )

    def __getattr__(self, name: str):
        return getattr(self.mp_context, name)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.kill_all_processes()
        self.log_bad_exit_codes()


def init_process_group(init_method: str, rank: int, world_size: int):
    torch._C._set_print_stack_traces_on_fatal_signal(True)

    if torch.cuda.device_count() >= world_size:
        backend = "nccl"
        torch.cuda.set_device(rank)
    else:
        # Use Gloo instead of NCCL so that we can run with fewer GPUs than
        # processes (e.g. on a single GPU, or on CPU only)
        backend = "gloo"

    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method,
    )


def _launch_subprocesses_fn_wrapper(
    init_method: str,
    rank: int,
    world_size: int,
    parent_env_vars: Dict[str, str],
    user_fn,
    args,
    kwargs,
):
    # Initialize the environment of the spawned subprocess by capturing and applying the environment
    # variables of the parent process. Clearing `os.environ` and then updating it with `parent_env_vars`
    # ensures that each subprocess starts with an environment that mirrors the parent at the time of job
    # submission, even (and especially) when the subprocesses are reused for subsequent job executions.
    os.environ.clear()
    os.environ.update(parent_env_vars)

    # Check if the process group is already initialized
    if not torch.distributed.is_initialized():
        init_process_group(init_method, rank, world_size)
    try:
        return user_fn(*args, **kwargs)
    finally:
        # should free all memory used by PyTorch in the subprocesses
        gc.collect()
        torch.cuda.empty_cache()


# Global dictionary to keep track of executors and temporary files
EXECUTORS_AND_FILES: Dict[
    int, Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]
] = {}


def get_global_pool_allocator(
    world_size: int,
) -> Tuple[_TemporaryFileWrapper, concurrent.futures.ProcessPoolExecutor]:
    global EXECUTORS_AND_FILES

    if world_size not in EXECUTORS_AND_FILES:
        rdv = NamedTemporaryFile(mode="w+b", buffering=-1, delete=False)
        mp_context = SafeMpContext()

        executor = concurrent.futures.ProcessPoolExecutor(
            max_workers=world_size, mp_context=mp_context
        )

        # Register the executor and temporary file in the global dict
        EXECUTORS_AND_FILES[world_size] = (rdv, executor)
    else:
        rdv, executor = EXECUTORS_AND_FILES[world_size]

    return rdv, executor


class ProcessPoolExecutorManager:
    def __init__(self, world_size: int):
        self.world_size = world_size

    def __enter__(self):
        # Before handing work to the subprocesses, free as much memory as possible
        # in the main process so that the subprocesses have room to allocate.
        gc.collect()
        torch.cuda.empty_cache()

        self.rdv, self.executor = get_global_pool_allocator(self.world_size)
        return self

    def submit(self, fn, *args, **kwargs):
        return self.executor.submit(fn, *args, **kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # One of the subprocess jobs has failed
        if exc_val:
            # We want to avoid killing the processes while the executor still believes they are
            # up and healthy, as this may have unintended consequences such as the executor
            # restarting the processes or reporting spurious errors.
            # Update the executor's internal state and call cancel() on each issued task that is
            # not yet executing.
            self.executor.shutdown(wait=False, cancel_futures=True)

            # Kill all remaining subprocesses
            mp_context = self.executor._mp_context
            mp_context.kill_all_processes()
            mp_context.log_bad_exit_codes()

            # We want to wait for all the futures to complete, which requires a second shutdown call
            self.executor.shutdown(wait=True)

            # Close the temporary file
            self.rdv.close()

            # Remove the executor from the global dict so that it gets recreated
            # the next time a test requires this world_size
            assert self.world_size in EXECUTORS_AND_FILES
            del EXECUTORS_AND_FILES[self.world_size]

            print(
                f"Shut down and removed the executor after a subprocess error. Executor count: {len(EXECUTORS_AND_FILES)}"
            )


def launch_subprocesses(world_size: int, fn, *args, **kwargs):
    # This custom manager lets each test execution enter/exit the following context.
    # On entry, it creates a new ProcessPoolExecutor with the given world size, or reuses an existing one.
    # On exit, it detects whether an exception occurred; if so, it kills all spawned processes and removes
    # the executor, so that the next request recreates the manager and respawns the processes.
    with ProcessPoolExecutorManager(world_size) as manager:
        futures = [
            manager.submit(
                _launch_subprocesses_fn_wrapper,
                init_method=f"file://{manager.rdv.name}",
                rank=rank,
                world_size=world_size,
                parent_env_vars=dict(os.environ),
                user_fn=fn,
                args=args,
                kwargs=kwargs,
            )
            for rank in range(world_size)
        ]
        done, _ = concurrent.futures.wait(
            futures, return_when=concurrent.futures.FIRST_EXCEPTION
        )
        for f in done:
            f.result()
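

# Illustrative usage (a sketch, not part of the original module): `_example_allreduce`
# is a hypothetical helper showing how `launch_subprocesses` is meant to be called. It
# runs once per rank inside the worker pool, with the default process group already
# initialized by `_launch_subprocesses_fn_wrapper`.
def _example_allreduce(value: int) -> None:
    # Use a CUDA tensor when running under NCCL, a CPU tensor under Gloo.
    device = "cuda" if torch.distributed.get_backend() == "nccl" else "cpu"
    t = torch.full((1,), float(value), device=device)
    torch.distributed.all_reduce(t)
    assert t.item() == value * torch.distributed.get_world_size()


if __name__ == "__main__":
    # Spawns (or reuses) a two-process pool, runs the check on every rank and
    # waits for both; f.result() above re-raises if either rank failed.
    launch_subprocesses(2, _example_allreduce, value=3)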