import collections
import itertools

import torch
from apex.contrib.optimizers.distributed_fused_adam import (
    DistributedFusedAdam,
    _coalescing_manager,
    _disable_pre_forward_hook,
)
from apex.transformer import parallel_state


def _str_to_dtype(dtype):
    """Convert a dtype name (e.g. 'fp32', 'fp16', 'bf16') to a torch.dtype."""
    if isinstance(dtype, torch.dtype):
        return dtype
    name = str(dtype).strip().lower()
    if name in ('', 'none'):
        return torch.float32
    elif name in ('torch.float32', 'float32', 'float', 'fp32', '32'):
        return torch.float32
    elif name in ('torch.float16', 'float16', 'half', 'fp16', '16'):
        return torch.float16
    elif name in ('torch.bfloat16', 'bfloat16', 'bf16'):
        return torch.bfloat16
    else:
        raise ValueError(f'unsupported dtype ({dtype})')


class MegatronDistributedFusedAdam(DistributedFusedAdam):
    """Wrapper class that supports NeMo-Megatron optimizations

    When O2-style optimizations are enabled, gradients are accumulated
    into the main_grad buffer instead of the grad buffer.
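
    A minimal, illustrative construction (hyperparameter values are
    hypothetical and torch.distributed / Megatron parallel state must
    already be initialized)::

        optimizer = MegatronDistributedFusedAdam(
            model.parameters(),
            lr=1e-4,
            weight_decay=0.01,
            dtype='bf16',
        )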

    """
    def __init__(self, params, disable_distributed_parameters=False, **kwargs):
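        # Initialize process groups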
        if 'process_group' not in kwargs and not parallel_state.is_unitialized():
            kwargs['process_group'] = parallel_state.get_data_parallel_group()
        if disable_distributed_parameters:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            self_groups = [torch.distributed.new_group(ranks=[i]) for i in range(world_size)]
            kwargs['distributed_process_group'] = self_groups[rank]
            kwargs['redundant_process_group'] = kwargs['process_group']

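        # Make sure dtype kwargs are torch.dtype objects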
        for keyword in ('dtype', 'grad_sync_dtype', 'param_sync_dtype'):
            if keyword in kwargs:
                kwargs[keyword] = _str_to_dtype(kwargs[keyword])

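        # Make sure params are in a consistent format (list of param group dicts)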
        param_groups = list(params)
        assert param_groups
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

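        # Check whether any params have been flagged for an explicit FP32 optimizer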
        self._fp32_optim = None
        distopt_param_groups = param_groups
        dtype = kwargs['dtype'] if 'dtype' in kwargs else torch.float32
        grad_sync_dtype = kwargs['grad_sync_dtype'] if 'grad_sync_dtype' in kwargs else dtype
        needs_fp32_optimizer = any(
            getattr(param, '_with_fp32_optimizer', False)
            for param in itertools.chain.from_iterable(param_group['params'] for param_group in param_groups)
        )
        if (dtype != torch.float32 or grad_sync_dtype != torch.float32) and needs_fp32_optimizer:
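            # Split params between the distributed optimizer and the explicit
            # FP32 optimizer, keeping an FP32 main copy of the latter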
            distopt_param_groups = []
            fp32_param_groups = []
            self._fp32_optim_main_params = collections.OrderedDict()
            for param_group in param_groups:
                distopt_param_group = {key: val for key, val in param_group.items() if key != 'params'}
                distopt_param_group['params'] = []
                fp32_param_group = {key: val for key, val in param_group.items() if key != 'params'}
                fp32_param_group['params'] = []
                for model_param in param_group['params']:
                    if getattr(model_param, '_with_fp32_optimizer', False):
                        main_param = model_param.detach().clone().float()
                        fp32_param_group['params'].append(main_param)
                        self._fp32_optim_main_params[model_param] = main_param
                    else:
                        distopt_param_group['params'].append(model_param)
                distopt_param_groups.append(distopt_param_group)
                fp32_param_groups.append(fp32_param_group)

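            # Construct the explicit FP32 optimizer (AdamW) for the FP32 params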
            adamw_kwargs = {}
            for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad'):
                if name in kwargs:
                    adamw_kwargs[name] = kwargs[name]
            self._fp32_optim = torch.optim.AdamW(fp32_param_groups, **adamw_kwargs)
            self._fp32_optim_grad_sync_needed = True

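        # Construct the distributed optimizer with the remaining params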
        super().__init__(distopt_param_groups, **kwargs)

    def _make_post_backward_hook(self, param, param_group_id, param_id):
        def hook(*unused):
            if getattr(param, '_pre_forward_hook_is_enabled', False):
                raise RuntimeError(
                    'A parameter called its post-backward hook '
                    'before its pre-forward hook. '
                    'Please manually interact with the parameter '
                    'before the forward pass (e.g. by calling data_ptr) '
                    'or run DistributedFusedAdam with overlap_param_sync=False.'
                )
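            # Initialize the param's optimizer state on first use, then eagerly
            # copy its grad and start bucket grad sync when enabled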
            with self._lock:
                need_to_initialize = 'fragments' not in self.state[param]
                if need_to_initialize:
                    self._init_param_state(param, param_group_id, param_id)
                if self.greedy_grad_copy and not getattr(param, '_disable_greedy_grad_copy', False):
                    self._grad_copy(param)
                    if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
                        self._try_start_bucket_grad_sync(
                            params=[param], ignore_last_bucket=need_to_initialize,
                        )

        return hook

    def _filter_distopt_params(self, params):
        if self._fp32_optim is None:
            return params
        if params is None:
            return None
        if isinstance(params, torch.Tensor):
            params = [params]
        return filter(lambda param: param not in self._fp32_optim_main_params, params)

    def parameters(self, with_fp32_optim_params=False):
        if with_fp32_optim_params and self._fp32_optim is not None:
            return itertools.chain(super().parameters(), self._fp32_optim_main_params.keys())
        else:
            return super().parameters()

    def init_params(self, params=None):
        super().init_params(self._filter_distopt_params(params))

    def init_params_bucket(self, params):
        super().init_params_bucket(self._filter_distopt_params(params))

    def try_grad_sync(self, params):
        params = self._filter_distopt_params(params)
        params = [p for p in params if not getattr(p, '_disable_greedy_grad_copy', False)]
        params = [p for p in params if not getattr(p, '_disable_overlap_grad_sync', False)]
        for p in params:
            self._grad_copy(p)
        self._try_start_bucket_grad_sync(params=params)

    def _try_start_bucket_param_sync(self, params=None):
        super()._try_start_bucket_param_sync(self._filter_distopt_params(params))

    def _fp32_optim_grad_sync(self):
        if self._fp32_optim is None or not self._fp32_optim_grad_sync_needed:
            return
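        # Accumulate model-param grads into the FP32 main-param grads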
        for model_param, main_param in self._fp32_optim_main_params.items():
            if model_param.grad is not None:
                main_param.grad += model_param.grad.detach()
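        # Average the FP32 main grads over the process group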
        sync_requests = []
        with _coalescing_manager(self.process_group, self.device, sync_requests):
            for main_param in self._fp32_optim_main_params.values():
                sync_requests.append(
                    torch.distributed.all_reduce(
                        main_param.grad, op=torch.distributed.ReduceOp.AVG, group=self.process_group, async_op=True,
                    )
                )
        for req in sync_requests:
            req.wait()
        self._fp32_optim_grad_sync_needed = False

    def zero_grad(self, *args, **kwargs):
        super().zero_grad(*args, **kwargs)
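        # Reset grads of params handled by the FP32 optimizer and re-attach
        # their main_grad buffers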
        if self._fp32_optim is not None:
            self._fp32_optim_grad_sync_needed = True
            self._fp32_optim.zero_grad(set_to_none=False)
            for model_param, main_param in self._fp32_optim_main_params.items():
                if main_param.grad is None:
                    main_param.grad = torch.zeros_like(main_param)
                if model_param.grad is not None:
                    model_param.grad.zero_()
                model_param.main_grad = main_param.grad

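        # Point main_grad at views into the contiguous grad buffer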
        if self.contiguous_grad_buffer:
            for param in self.parameters():
                with _disable_pre_forward_hook(param):
                    param.main_grad = self.grad_buffer_view(param)

    def grad_norm(self, parameters=None, norm_type=2.0, force=False):
        assert norm_type == 2

        if parameters is not None:
            parameters = list(parameters)

        if force or self._grad_norm is None:
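            # Compute norm of local grads for the distributed optimizer's params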
            grad_norm_sq = self._local_grad_norm(
                parameters=self._filter_distopt_params(parameters), norm_type=norm_type,
            )
            if self.redundant_size > 1:
                grad_norm_sq /= self.redundant_size
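            # Include grads of params handled by the FP32 optimizer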
            if self._fp32_optim is not None:
                self._fp32_optim_grad_sync()
                if parameters is None:
                    for main_param in self._fp32_optim_main_params.values():
                        grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size
                else:
                    for model_param in parameters:
                        if model_param in self._fp32_optim_main_params:
                            main_param = self._fp32_optim_main_params[model_param]
                            grad_norm_sq += torch.linalg.norm(main_param.grad) ** 2 / self.process_group_size
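            # Sum squared norms over all ranks before taking the square root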
            torch.distributed.all_reduce(
                grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
            )
            self._grad_norm = grad_norm_sq.sqrt()

        return super().grad_norm()

    def step(self, closure=None, *, grad_scaler=None):
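        # Apply the distributed optimizer step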
        loss = super().step(closure=closure, grad_scaler=grad_scaler)

        if self._fp32_optim is not None:
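            # If the grad scaler found inf/NaN grads, skip the FP32 optimizer step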
            if grad_scaler is not None:
                scaler_state = grad_scaler._per_optimizer_states[id(self)]
                for _, found_inf in scaler_state['found_inf_per_device'].items():
                    if found_inf.item():
                        return loss
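            # Keep the FP32 optimizer's learning rate in sync with the distributed optimizer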
            for distopt_group, fp32_optim_group in zip(self.param_groups, self._fp32_optim.param_groups):
                fp32_optim_group['lr'] = distopt_group['lr']
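            # Sync and scale the FP32 grads, apply the FP32 optimizer step,
            # and copy the updated values back to the model params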
            self._fp32_optim_grad_sync()
            for main_param in self._fp32_optim_main_params.values():
                main_param.grad *= self._grad_scale
            self._fp32_optim.step()
            for model_param, main_param in self._fp32_optim_main_params.items():
                model_param.detach().copy_(main_param.detach())

        return loss

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        if self._fp32_optim is not None and state_dict is not None:
            state_dict['fp32_optim'] = self._fp32_optim.state_dict()
            state_dict['fp32_optim_fp32_params'] = list(self._fp32_optim_main_params.values())
        return state_dict

    def load_state_dict(self, state_dict):
        if self._fp32_optim is not None and 'fp32_optim' in state_dict:
            self._fp32_optim.load_state_dict(state_dict['fp32_optim'])
            del state_dict['fp32_optim']
            for old_main_param, new_main_param in zip(
                self._fp32_optim_main_params.values(), state_dict['fp32_optim_fp32_params']
            ):
                old_main_param.copy_(new_main_param.detach())
            del state_dict['fp32_optim_fp32_params']
        return super().load_state_dict(state_dict)