File size: 12,969 Bytes
7934b29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import itertools
import torch
from apex.contrib.optimizers.distributed_fused_adam import (
DistributedFusedAdam,
_coalescing_manager,
_disable_pre_forward_hook,
)
from apex.transformer import parallel_state
def _str_to_dtype(dtype):
if isinstance(dtype, torch.dtype):
return dtype
name = str(dtype).strip().lower()
if name in ('', 'none'):
return torch.float32
elif name in ('torch.float32', 'float32', 'float', 'fp32', '32'):
return torch.float32
elif name in ('torch.float16', 'float16', 'half', 'fp16', '16'):
return torch.float16
elif name in ('torch.bfloat16', 'bfloat16', 'bf16'):
return torch.bfloat16
else:
raise ValueError(f'unsupported dtype ({dtype})')
class MegatronDistributedFusedAdam(DistributedFusedAdam):
    """Wrapper class that supports NeMo-Megatron optimizations

    When O2-style optimizations are enabled, gradients are accumulated
    into the main_grad buffer instead of the grad buffer.

    Parameters that carry a truthy ``_with_fp32_optimizer`` attribute
    are, whenever ``dtype`` or ``grad_sync_dtype`` is not FP32, removed
    from the distributed optimizer and updated by a separate
    ``torch.optim.AdamW`` instance that operates on FP32 copies
    ("main params") of those parameters.

    Arguments:
        params (iterable): parameters or param group dicts to optimize.
        disable_distributed_parameters (bool, optional): if true, each
            rank gets its own single-rank distributed process group
            and the full process group is passed as the redundant
            process group (default: False).
        **kwargs: forwarded to ``DistributedFusedAdam``. ``dtype``,
            ``grad_sync_dtype`` and ``param_sync_dtype`` may be given
            as strings.
    """

    def __init__(self, params, disable_distributed_parameters=False, **kwargs):

        # Initialize process groups
        if 'process_group' not in kwargs and not parallel_state.is_unitialized():
            kwargs['process_group'] = parallel_state.get_data_parallel_group()
        if disable_distributed_parameters:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
            kwargs['distributed_process_group'] = self_groups[rank]
            # Fall back to the default torch.distributed group instead
            # of failing with KeyError when no process group is known.
            # NOTE(review): assumes the base optimizer also defaults
            # to the world group when process_group is omitted - confirm.
            redundant_group = kwargs.get('process_group')
            if redundant_group is None:
                redundant_group = torch.distributed.group.WORLD
            kwargs['redundant_process_group'] = redundant_group

        # Make sure dtypes are in right type
        for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
            if keyword in kwargs:
                kwargs[keyword] = _str_to_dtype(kwargs[keyword])

        # Make sure params are in consistent format (list of param group dicts)
        param_groups = list(params)
        assert param_groups
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        # State for the explicit FP32 optimizer. Always defined so
        # that other methods can rely on the attributes existing.
        self._fp32_optim = None
        self._fp32_optim_main_params = collections.OrderedDict()
        self._fp32_optim_grad_sync_needed = False

        # Check if explicit FP32 optimizer is needed
        distopt_param_groups = param_groups
        dtype = kwargs.get('dtype', torch.float32)
        grad_sync_dtype = kwargs.get('grad_sync_dtype', dtype)
        needs_fp32_optimizer = any(
            getattr(param, '_with_fp32_optimizer', False)
            for param in itertools.chain.from_iterable(param_group['params'] for param_group in param_groups)
        )
        if (dtype != torch.float32 or grad_sync_dtype != torch.float32) and needs_fp32_optimizer:

            # Find params that require explicit FP32 optimizer
            distopt_param_groups = []
            fp32_param_groups = []
            for param_group in param_groups:
                # Both optimizers inherit the group's hyperparameters
                group_options = {key: val for key, val in param_group.items() if key != 'params'}
                distopt_param_group = dict(group_options, params=[])
                fp32_param_group = dict(group_options, params=[])
                for model_param in param_group['params']:
                    if getattr(model_param, '_with_fp32_optimizer', False):
                        main_param = model_param.detach().clone().float()
                        fp32_param_group['params'].append(main_param)
                        self._fp32_optim_main_params[model_param] = main_param
                    else:
                        distopt_param_group['params'].append(model_param)
                distopt_param_groups.append(distopt_param_group)
                fp32_param_groups.append(fp32_param_group)

            # Construct explicit FP32 optimizer
            adamw_kwargs = {
                name: kwargs[name]
                for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad')
                if name in kwargs
            }
            self._fp32_optim = torch.optim.AdamW(fp32_param_groups, **adamw_kwargs)
            self._fp32_optim_grad_sync_needed = True

        # Construct distributed optimizer
        super().__init__(distopt_param_groups, **kwargs)

    def _make_post_backward_hook(self, param, param_group_id, param_id):
        """Create the callback to run after ``param``'s backward pass.

        The returned hook lazily initializes the optimizer state for
        ``param``, copies its gradient if greedy grad copy is enabled,
        and tries to launch bucket grad syncs if overlapped grad sync
        is enabled. A param can opt out of the latter two steps with
        the ``_disable_greedy_grad_copy`` and
        ``_disable_overlap_grad_sync`` attributes.

        Raises (from the hook):
            RuntimeError: if the param's ``_pre_forward_hook_is_enabled``
                attribute is truthy when the hook runs.
        """

        def hook(*unused):
            if getattr(param, '_pre_forward_hook_is_enabled', False):
                raise RuntimeError(
                    'A parameter called its post-backward hook '
                    'before its pre-forward hook. '
                    'Please manually interact with the parameter '
                    'before the forward pass (e.g. by calling data_ptr) '
                    'or run DistributedFusedAdam with overlap_param_sync=False.'
                )
            with self._lock:
                need_to_initialize = 'fragments' not in self.state[param]
                if need_to_initialize:
                    self._init_param_state(param, param_group_id, param_id)
                if self.greedy_grad_copy and not getattr(param, '_disable_greedy_grad_copy', False):
                    self._grad_copy(param)
                if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
                    # Skip the last bucket when the param state was
                    # only just initialized
                    self._try_start_bucket_grad_sync(
                        params=[param], ignore_last_bucket=need_to_initialize,
                    )

        return hook

    def _filter_distopt_params(self, params):
        """Drop params that are handled by the explicit FP32 optimizer.

        Arguments:
            params: ``None``, a single tensor, or an iterable of
                tensors.

        Returns:
            ``params`` unchanged if there is no FP32 optimizer,
            ``None`` if ``params`` is ``None``, otherwise a list of
            the params that are not keys of the FP32 optimizer's
            param map. A list (rather than a lazy ``filter`` object)
            is returned so the result can be iterated more than once.
        """
        if self._fp32_optim is None:
            return params
        if params is None:
            return None
        if isinstance(params, torch.Tensor):
            params = [params]
        return [param for param in params if param not in self._fp32_optim_main_params]

    def parameters(self, with_fp32_optim_params=False):
        """Iterate over the optimizer's parameters.

        Arguments:
            with_fp32_optim_params (bool, optional): also yield the
                model params that are updated by the explicit FP32
                optimizer (default: False).
        """
        if with_fp32_optim_params and self._fp32_optim is not None:
            return itertools.chain(super().parameters(), self._fp32_optim_main_params.keys())
        else:
            return super().parameters()

    def init_params(self, params=None):
        """Initialize distributed optimizer state, skipping FP32-optimizer params."""
        super().init_params(self._filter_distopt_params(params))

    def init_params_bucket(self, params):
        """Initialize a bucket of distributed optimizer state, skipping FP32-optimizer params."""
        super().init_params_bucket(self._filter_distopt_params(params))

    def try_grad_sync(self, params):
        """Copy grads and try to launch bucket grad syncs for ``params``.

        Params handled by the explicit FP32 optimizer, and params with
        a truthy ``_disable_greedy_grad_copy`` or
        ``_disable_overlap_grad_sync`` attribute, are skipped.

        Arguments:
            params: a single tensor or an iterable of tensors.
        """
        # Treat a lone tensor as a one-element list, so that it is not
        # iterated row by row when there is no FP32 optimizer
        if isinstance(params, torch.Tensor):
            params = [params]
        params = [
            p
            for p in self._filter_distopt_params(params)
            if not getattr(p, '_disable_greedy_grad_copy', False)
            and not getattr(p, '_disable_overlap_grad_sync', False)
        ]
        for p in params:
            self._grad_copy(p)
        self._try_start_bucket_grad_sync(params=params)

    def _try_start_bucket_param_sync(self, params=None):
        """Try to launch bucket param syncs, skipping FP32-optimizer params."""
        super()._try_start_bucket_param_sync(self._filter_distopt_params(params))

    def _fp32_optim_grad_sync(self):
        """Average the explicit FP32 optimizer's grads over the process group.

        Any grads found in the model params' ``grad`` are first added
        to the corresponding main param's ``grad``. Does nothing if
        there is no FP32 optimizer or if the grads have already been
        synchronized since the last ``zero_grad``.
        """
        if self._fp32_optim is None or not self._fp32_optim_grad_sync_needed:
            return
        for model_param, main_param in self._fp32_optim_main_params.items():
            # Main grads are normally allocated in zero_grad. Allocate
            # them here if that has not happened yet, since every
            # main grad takes part in the all-reduce below.
            if main_param.grad is None:
                main_param.grad = torch.zeros_like(main_param)
            if model_param.grad is not None:
                main_param.grad += model_param.grad.detach()
        sync_requests = []
        with _coalescing_manager(self.process_group, self.device, sync_requests):
            for main_param in self._fp32_optim_main_params.values():
                sync_requests.append(
                    torch.distributed.all_reduce(
                        main_param.grad, op=torch.distributed.ReduceOp.AVG, group=self.process_group, async_op=True,
                    )
                )
        for req in sync_requests:
            req.wait()
        self._fp32_optim_grad_sync_needed = False

    def zero_grad(self, *args, **kwargs):
        """Reset grads of the distributed and the explicit FP32 optimizer.

        Arguments are forwarded to ``DistributedFusedAdam.zero_grad``.
        Afterwards the ``main_grad`` attribute of each model param is
        pointed at the buffer that receives its gradients.
        """
        super().zero_grad(*args, **kwargs)

        # Reset grads for explicit FP32 optimizer
        if self._fp32_optim is not None:
            self._fp32_optim_grad_sync_needed = True
            self._fp32_optim.zero_grad(set_to_none=False)
            for model_param, main_param in self._fp32_optim_main_params.items():
                if main_param.grad is None:
                    main_param.grad = torch.zeros_like(main_param)
                if model_param.grad is not None:
                    model_param.grad.zero_()
                # Grads accumulated into main_grad go straight into
                # the FP32 main param's grad
                model_param.main_grad = main_param.grad

        # Reset main grads
        if self.contiguous_grad_buffer:
            for param in self.parameters():
                with _disable_pre_forward_hook(param):
                    param.main_grad = self.grad_buffer_view(param)

    def grad_norm(self, parameters=None, norm_type=2.0, force=False):
        """Compute the L2 norm of the gradients.

        Arguments:
            parameters (iterable, optional): params to include in the
                norm (default: all params of both optimizers).
            norm_type (float, optional): must be 2 (default: 2.0).
            force (bool, optional): recompute the norm even if a
                cached value is available (default: False).

        Returns:
            The value returned by ``DistributedFusedAdam.grad_norm``
            after the cached norm has been refreshed if needed.
        """
        assert norm_type == 2

        if parameters is not None:
            # Make sure we can access iterable multiple times
            parameters = list(parameters)

        # Compute grad norm
        if force or self._grad_norm is None:

            # Compute norm of local gradients for distributed optimizer
            grad_norm_sq = self._local_grad_norm(
                parameters=self._filter_distopt_params(parameters), norm_type=norm_type,
            )
            if self.redundant_size > 1:
                grad_norm_sq /= self.redundant_size

            # Compute norm of local gradients for explicit FP32 optimizer
            if self._fp32_optim is not None:
                self._fp32_optim_grad_sync()
                if parameters is None:
                    main_params = list(self._fp32_optim_main_params.values())
                else:
                    main_params = [
                        self._fp32_optim_main_params[model_param]
                        for model_param in parameters
                        if model_param in self._fp32_optim_main_params
                    ]
                # Each rank in the process group contributes the same
                # synchronized grads, so scale by the group size
                # before the sum-reduction below
                for main_param in main_params:
                    grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size

            # Sum over all procs to get grad norm
            torch.distributed.all_reduce(
                grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
            )
            self._grad_norm = grad_norm_sq.sqrt()

        # Use cached grad norm
        return super().grad_norm()

    def step(self, closure=None, *, grad_scaler=None):
        """Apply the distributed optimizer and the explicit FP32 optimizer.

        Arguments:
            closure (callable, optional): forwarded to
                ``DistributedFusedAdam.step``.
            grad_scaler (optional): forwarded to
                ``DistributedFusedAdam.step``. If it has recorded an
                inf for this optimizer, the FP32 optimizer step is
                skipped.

        Returns:
            The value returned by ``DistributedFusedAdam.step``.
        """

        # Apply distributed optimizer
        loss = super().step(closure=closure, grad_scaler=grad_scaler)

        if self._fp32_optim is not None:

            # Handle grad scaling
            if grad_scaler is not None:
                scaler_state = grad_scaler._per_optimizer_states[id(self)]
                for found_inf in scaler_state['found_inf_per_device'].values():
                    if found_inf.item():
                        return loss

            # Update learning rate
            for distopt_group, fp32_optim_group in zip(self.param_groups, self._fp32_optim.param_groups):
                fp32_optim_group['lr'] = distopt_group['lr']

            # Apply explicit FP32 optimizer
            self._fp32_optim_grad_sync()
            for main_param in self._fp32_optim_main_params.values():
                main_param.grad *= self._grad_scale
            self._fp32_optim.step()
            # Write the updated FP32 values back into the model params
            for model_param, main_param in self._fp32_optim_main_params.items():
                model_param.detach().copy_(main_param.detach())

        return loss

    def state_dict(self, *args, **kwargs):
        """Return the optimizer state, including the explicit FP32 optimizer.

        Arguments are forwarded to ``DistributedFusedAdam.state_dict``.
        If that returns a state dict and an FP32 optimizer exists, its
        state and its FP32 main params are added under the keys
        ``'fp32_optim'`` and ``'fp32_optim_fp32_params'``.
        """
        state_dict = super().state_dict(*args, **kwargs)
        if self._fp32_optim is not None and state_dict is not None:
            state_dict['fp32_optim'] = self._fp32_optim.state_dict()
            state_dict['fp32_optim_fp32_params'] = list(self._fp32_optim_main_params.values())
        return state_dict

    def load_state_dict(self, state_dict):
        """Load optimizer state, including the explicit FP32 optimizer.

        The FP32 optimizer's entries are restored here and are not
        passed on to ``DistributedFusedAdam.load_state_dict``. The
        caller's ``state_dict`` is left unmodified.
        """
        if self._fp32_optim is not None and 'fp32_optim' in state_dict:
            self._fp32_optim.load_state_dict(state_dict['fp32_optim'])
            for old_main_param, new_main_param in zip(
                self._fp32_optim_main_params.values(), state_dict['fp32_optim_fp32_params']
            ):
                old_main_param.copy_(new_main_param.detach())
            # Strip the FP32 optimizer entries from a copy rather than
            # deleting them from the caller's dict
            fp32_optim_keys = ('fp32_optim', 'fp32_optim_fp32_params')
            state_dict = {key: val for key, val in state_dict.items() if key not in fp32_optim_keys}
        return super().load_state_dict(state_dict)
|