# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import itertools

import torch
from apex.contrib.optimizers.distributed_fused_adam import (
    DistributedFusedAdam,
    _coalescing_manager,
    _disable_pre_forward_hook,
)
from apex.transformer import parallel_state


def _str_to_dtype(dtype):
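    """Convert a dtype name (e.g. 'fp16', 'bf16') to a torch.dtype

    torch.dtype inputs are passed through unchanged and None defaults
    to FP32.

    """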
    if isinstance(dtype, torch.dtype):
        return dtype
    name = str(dtype).strip().lower()
    if name in ('', 'none'):
        return torch.float32
    elif name in ('torch.float32', 'float32', 'float', 'fp32', '32'):
        return torch.float32
    elif name in ('torch.float16', 'float16', 'half', 'fp16', '16'):
        return torch.float16
    elif name in ('torch.bfloat16', 'bfloat16', 'bf16'):
        return torch.bfloat16
    else:
        raise ValueError(f'unsupported dtype ({dtype})')


class MegatronDistributedFusedAdam(DistributedFusedAdam):
    """Wrapper class that supports NeMo-Megatron optimizations

    When O2-style optimizations are enabled, gradients are accumulated
    into the main_grad buffer instead of the grad buffer.
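
    When the optimizer is not already running in FP32, parameters
    flagged with the `_with_fp32_optimizer` attribute are handled by a
    separate FP32 torch.optim.AdamW optimizer instead of the
    distributed optimizer.

    A minimal usage sketch (illustrative only; assumes Megatron
    parallel state has been initialized and `model` is the module
    being optimized):

        optimizer = MegatronDistributedFusedAdam(
            model.parameters(),
            lr=1e-4,
            dtype=torch.bfloat16,
        )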

    """

    def __init__(self, params, disable_distributed_parameters=False, **kwargs):

        # Initialize process groups
        if 'process_group' not in kwargs and not parallel_state.is_unitialized():
            kwargs['process_group'] = parallel_state.get_data_parallel_group()
        if disable_distributed_parameters:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
            kwargs['distributed_process_group'] = self_groups[rank]
            kwargs['redundant_process_group'] = kwargs['process_group']

        # Make sure dtype arguments are torch.dtype objects
        for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
            if keyword in kwargs:
                kwargs[keyword] = _str_to_dtype(kwargs[keyword])

        # Make sure params are in consistent format (list of param group dicts)
        param_groups = list(params)
        assert param_groups
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        # Check if explicit FP32 optimizer is needed
        self._fp32_optim = None
        distopt_param_groups = param_groups
        dtype = kwargs.get('dtype', torch.float32)
        grad_sync_dtype = kwargs.get('grad_sync_dtype', dtype)
        needs_fp32_optimizer = any(
            getattr(param, '_with_fp32_optimizer', False)
            for param in itertools.chain.from_iterable(param_group['params'] for param_group in param_groups)
        )
        if (dtype != torch.float32 or grad_sync_dtype != torch.float32) and needs_fp32_optimizer:

            # Find params that require explicit FP32 optimizer
            distopt_param_groups = []
            fp32_param_groups = []
            self._fp32_optim_main_params = collections.OrderedDict()
            for param_group in param_groups:
                distopt_param_group = {key: val for key, val in param_group.items() if key != 'params'}
                distopt_param_group['params'] = []
                fp32_param_group = {key: val for key, val in param_group.items() if key != 'params'}
                fp32_param_group['params'] = []
                for model_param in param_group['params']:
                    if getattr(model_param, '_with_fp32_optimizer', False):
                        main_param = model_param.detach().clone().float()
                        fp32_param_group['params'].append(main_param)
                        self._fp32_optim_main_params[model_param] = main_param
                    else:
                        distopt_param_group['params'].append(model_param)
                distopt_param_groups.append(distopt_param_group)
                fp32_param_groups.append(fp32_param_group)

            # Construct explicit FP32 optimizer
            adamw_kwargs = {}
            for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad'):
                if name in kwargs:
                    adamw_kwargs[name] = kwargs[name]
            self._fp32_optim = torch.optim.AdamW(fp32_param_groups, **adamw_kwargs)
            self._fp32_optim_grad_sync_needed = True

        # Construct distributed optimizer
        super().__init__(distopt_param_groups, **kwargs)

    def _make_post_backward_hook(self, param, param_group_id, param_id):
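        """Create a post-backward hook for a parameter

        The hook lazily initializes the parameter's optimizer state,
        eagerly copies its gradient into the optimizer's grad buckets
        when greedy grad copies are enabled, and may launch an
        asynchronous grad sync.

        """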
        def hook(*unused):
            if getattr(param, '_pre_forward_hook_is_enabled', False):
                raise RuntimeError(
                    'A parameter called its post-backward hook '
                    'before its pre-forward hook. '
                    'Please manually interact with the parameter '
                    'before the forward pass (e.g. by calling data_ptr) '
                    'or run DistributedFusedAdam with overlap_param_sync=False.'
                )
            with self._lock:
                need_to_initialize = 'fragments' not in self.state[param]
                if need_to_initialize:
                    self._init_param_state(param, param_group_id, param_id)
                if self.greedy_grad_copy and not getattr(param, '_disable_greedy_grad_copy', False):
                    self._grad_copy(param)
                    if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
                        self._try_start_bucket_grad_sync(
                            params=[param], ignore_last_bucket=need_to_initialize,
                        )

        return hook

    def _filter_distopt_params(self, params):
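        """Remove params that are handled by the explicit FP32 optimizer"""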
        if self._fp32_optim is None:
            return params
        if params is None:
            return None
        if isinstance(params, torch.Tensor):
            params = [params]
        return filter(lambda param: param not in self._fp32_optim_main_params, params)

    def parameters(self, with_fp32_optim_params=False):
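        """Return params, optionally including those handled by the explicit FP32 optimizer"""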
        if with_fp32_optim_params and self._fp32_optim is not None:
            return itertools.chain(super().parameters(), self._fp32_optim_main_params.keys())
        else:
            return super().parameters()

    def init_params(self, params=None):
        super().init_params(self._filter_distopt_params(params))

    def init_params_bucket(self, params):
        super().init_params_bucket(self._filter_distopt_params(params))

    def try_grad_sync(self, params):
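        """Copy grads into buckets and attempt to start an asynchronous grad sync"""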
        params = self._filter_distopt_params(params)
        params = [p for p in params if not getattr(p, '_disable_greedy_grad_copy', False)]
        params = [p for p in params if not getattr(p, '_disable_overlap_grad_sync', False)]
        for p in params:
            self._grad_copy(p)
        self._try_start_bucket_grad_sync(params=params)

    def _try_start_bucket_param_sync(self, params=None):
        super()._try_start_bucket_param_sync(self._filter_distopt_params(params))

    def _fp32_optim_grad_sync(self):
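        """Accumulate grads into FP32 main grads and average them over the process group"""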
        if self._fp32_optim is None or not self._fp32_optim_grad_sync_needed:
            return
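        # Accumulate model grads into FP32 main grads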
        for model_param, main_param in self._fp32_optim_main_params.items():
            if model_param.grad is not None:
                main_param.grad += model_param.grad.detach()
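        # All-reduce (average) FP32 main grads over the process group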
        sync_requests = []
        with _coalescing_manager(self.process_group, self.device, sync_requests):
            for main_param in self._fp32_optim_main_params.values():
                sync_requests.append(
                    torch.distributed.all_reduce(
                        main_param.grad, op=torch.distributed.ReduceOp.AVG, group=self.process_group, async_op=True,
                    )
                )
        for req in sync_requests:
            req.wait()
        self._fp32_optim_grad_sync_needed = False

    def zero_grad(self, *args, **kwargs):
        super().zero_grad(*args, **kwargs)

        # Reset grads for explicit FP32 optimizer
        if self._fp32_optim is not None:
            self._fp32_optim_grad_sync_needed = True
            self._fp32_optim.zero_grad(set_to_none=False)
            for model_param, main_param in self._fp32_optim_main_params.items():
                if main_param.grad is None:
                    main_param.grad = torch.zeros_like(main_param)
                if model_param.grad is not None:
                    model_param.grad.zero_()
                model_param.main_grad = main_param.grad

        # Reset main grads
        if self.contiguous_grad_buffer:
            for param in self.parameters():
                with _disable_pre_forward_hook(param):
                    param.main_grad = self.grad_buffer_view(param)

    def grad_norm(self, parameters=None, norm_type=2.0, force=False):
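        """Compute the 2-norm of gradients across all processes

        The result is cached in self._grad_norm and reused unless
        force=True.

        """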
        assert norm_type == 2, f'Only 2-norm is supported, got norm_type={norm_type}'

        if parameters is not None:
            # Make sure we can access iterable multiple times
            parameters = list(parameters)

        # Compute grad norm
        if force or self._grad_norm is None:

            # Compute norm of local gradients for distributed optimizer
            grad_norm_sq = self._local_grad_norm(
                parameters=self._filter_distopt_params(parameters), norm_type=norm_type,
            )
            if self.redundant_size > 1:
                grad_norm_sq /= self.redundant_size

            # Compute norm of local gradients for explicit FP32 optimizer
            if self._fp32_optim is not None:
                self._fp32_optim_grad_sync()
                if parameters is None:
                    for main_param in self._fp32_optim_main_params.values():
                        grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size
                else:
                    for model_param in parameters:
                        if model_param in self._fp32_optim_main_params:
                            main_param = self._fp32_optim_main_params[model_param]
                            grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size

            # Sum over all procs to get grad norm
            torch.distributed.all_reduce(
                grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
            )
            self._grad_norm = grad_norm_sq.sqrt()

        # Use cached grad norm
        return super().grad_norm()

    def step(self, closure=None, *, grad_scaler=None):
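        """Apply an optimization step

        Runs the distributed optimizer step and then, if present, the
        explicit FP32 optimizer step. The FP32 step is skipped when the
        grad scaler has found inf/NaN grads.

        """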

        # Apply distributed optimizer
        loss = super().step(closure=closure, grad_scaler=grad_scaler)

        if self._fp32_optim is not None:

            # Handle grad scaling
            if grad_scaler is not None:
                scaler_state = grad_scaler._per_optimizer_states[id(self)]
                for _, found_inf in scaler_state['found_inf_per_device'].items():
                    if found_inf.item():
                        return loss

            # Update learning rate
            for distopt_group, fp32_optim_group in zip(self.param_groups, self._fp32_optim.param_groups):
                fp32_optim_group['lr'] = distopt_group['lr']

            # Apply explicit FP32 optimizer
            self._fp32_optim_grad_sync()
            for main_param in self._fp32_optim_main_params.values():
                main_param.grad *= self._grad_scale
            self._fp32_optim.step()
            for model_param, main_param in self._fp32_optim_main_params.items():
                model_param.detach().copy_(main_param.detach())

        return loss

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
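        # Add explicit FP32 optimizer state to the checkpoint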
        if self._fp32_optim is not None and state_dict is not None:
            state_dict['fp32_optim'] = self._fp32_optim.state_dict()
            state_dict['fp32_optim_fp32_params'] = list(self._fp32_optim_main_params.values())
        return state_dict

    def load_state_dict(self, state_dict):
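        # Restore explicit FP32 optimizer state before loading the
        # distributed optimizer state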
        if self._fp32_optim is not None and 'fp32_optim' in state_dict:
            self._fp32_optim.load_state_dict(state_dict['fp32_optim'])
            del state_dict['fp32_optim']
            for old_main_param, new_main_param in zip(
                self._fp32_optim_main_params.values(), state_dict['fp32_optim_fp32_params']
            ):
                old_main_param.copy_(new_main_param.detach())
            del state_dict['fp32_optim_fp32_params']
        return super().load_state_dict(state_dict)