# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer
from omegaconf import II, DictConfig

try:
    from deepspeed.ops.op_builder import CPUAdamBuilder

    has_deepspeed_cpu_adam = True
except ImportError:
    has_deepspeed_cpu_adam = False

@dataclass
class FairseqCPUAdamConfig(FairseqDataclass):
    adam_betas: str = field(
        default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"}
    )
    adam_eps: float = field(
        default=1e-8, metadata={"help": "epsilon for Adam optimizer"}
    )
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    fp16_adam_stats: bool = field(
        default=False, metadata={"help": "use FP16 stats (with automatic scaling)"}
    )
    # TODO common vars below in parent
    lr: List[float] = II("optimization.lr")

@register_optimizer("cpu_adam", dataclass=FairseqCPUAdamConfig)
class FairseqCPUAdam(FairseqOptimizer):
    """Adam optimizer for fairseq, optimized for CPU tensors.

    Important note: this optimizer corresponds to the "AdamW" variant of
    Adam in its weight decay behavior. As such, it is most closely
    analogous to torch.optim.AdamW from PyTorch.
    """

    def __init__(self, cfg: DictConfig, params):
        super().__init__(cfg)
        self._optimizer = CPUAdam(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "betas": eval(self.cfg.adam_betas),
            "eps": self.cfg.adam_eps,
            "weight_decay": self.cfg.weight_decay,
            "use_fp16_stats": self.cfg.fp16_adam_stats,
        }


class CPUAdam(torch.optim.Optimizer):

    optimizer_id = 0

    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        use_fp16_stats=False,
    ):
        defaults = {
            "lr": lr,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        self.use_fp16_stats = use_fp16_stats
        self.FLOAT16_MAX = 65504.0

        if not has_deepspeed_cpu_adam:
            raise ImportError("Please install DeepSpeed: pip install deepspeed")

        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
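        # Load (and, if necessary, JIT-build) the DeepSpeed CPU Adam op, then
        # register this instance's hyperparameters under its opt_id so later
        # adam_update() calls can find them. adamw_mode selects the
        # decoupled-weight-decay (AdamW) update.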
        self.ds_opt_adam = CPUAdamBuilder().load()
        adamw_mode = True
        self.ds_opt_adam.create_adam(
            self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    dtype = torch.float16 if self.use_fp16_stats else p.data.dtype
                    # gradient momentums
                    state["exp_avg"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    # gradient variances
                    state["exp_avg_sq"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    if self.use_fp16_stats:
                        assert torch.is_floating_point(p.data)
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                p_data_bak = p.data  # backup of the original data pointer

                p.data = p.data.to(dtype=torch.float32, device="cpu")
                p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu")

                if self.use_fp16_stats:
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]

                state["step"] += 1
                beta1, beta2 = group["betas"]
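                # Fused DeepSpeed CPU Adam(W) update: modifies p.data, exp_avg
                # and exp_avg_sq in place, using the hyperparameters
                # registered for this opt_id in __init__.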
                self.ds_opt_adam.adam_update(
                    self.opt_id,
                    state["step"],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group["weight_decay"],
                    group["bias_correction"],
                    p.data,
                    p.grad.data,
                    exp_avg,
                    exp_avg_sq,
                )

                if p_data_bak.data_ptr() != p.data.data_ptr():
                    p_data_bak.copy_(p.data)
                    p.data = p_data_bak
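                # FP16 stats are stored as half tensors plus a per-tensor
                # scale chosen so the largest magnitude fits in the FP16
                # range: scale = 1e-8 + ||t||_inf / FLOAT16_MAX, with t/scale
                # kept in half precision and the scale multiplied back in at
                # the start of the next step.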
                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss
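

if __name__ == "__main__":
    # Minimal usage sketch: assumes a working DeepSpeed install with the CPU
    # Adam op available (otherwise CPUAdam.__init__ raises ImportError above).
    # Runs a single AdamW-style update on a tiny CPU model.
    model = torch.nn.Linear(8, 2)
    optimizer = CPUAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print("completed one CPUAdam step")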