# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

import torch
from fairseq.dataclass import FairseqDataclass
from omegaconf import II, DictConfig
from torch.optim.optimizer import Optimizer, required

from . import FairseqOptimizer, register_optimizer

@dataclass
class FairseqNAGConfig(FairseqDataclass):
    momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    # TODO common vars in parent class
    lr: List[float] = II("optimization.lr")
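
# Note (not part of the original module): the learning rate is not a field of
# this dataclass; II("optimization.lr") interpolates it from the top-level
# optimization config. With fairseq's usual dataclass-to-CLI flag generation,
# selecting this optimizer would look roughly like the following (illustrative
# sketch; flag spellings are an assumption, not taken from this file):
#
#   fairseq-train ... --optimizer nag --lr 0.25 --momentum 0.99 --weight-decay 0.0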

@register_optimizer("nag", dataclass=FairseqNAGConfig)
class FairseqNAG(FairseqOptimizer):
    def __init__(self, cfg: DictConfig, params):
        super().__init__(cfg)
        self._optimizer = NAG(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "momentum": self.cfg.momentum,
            "weight_decay": self.cfg.weight_decay,
        }
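
# Illustrative sketch (not part of the original module): the `optimizer_config`
# property above is what lets a checkpoint be resumed with different optimizer
# hyperparameters. Assuming FairseqOptimizer.__init__ simply stores `cfg`,
# direct construction would look roughly like:
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({"lr": [0.25], "momentum": 0.99, "weight_decay": 0.0})
#   opt = FairseqNAG(cfg, model.parameters())  # `model` is a hypothetical nn.Module
#   opt.optimizer_config  # -> {"lr": 0.25, "momentum": 0.99, "weight_decay": 0.0}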

class NAG(Optimizer):
    def __init__(self, params, lr=required, momentum=0, weight_decay=0):
        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
        super(NAG, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            lr = group["lr"]
            lr_old = group.get("lr_old", lr)
            # Rescale the momentum buffer if the learning rate has changed
            # since the previous step.
            lr_correct = lr / lr_old if lr_old > 0 else lr

            for p in group["params"]:
                if p.grad is None:
                    continue

                # Work in fp32 even if the parameter is stored in a
                # half-precision dtype (memory-efficient fp16 training).
                p_data_fp32 = p.data
                if p_data_fp32.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                d_p = p.grad.data.float()
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    param_state["momentum_buffer"] = torch.zeros_like(d_p)
                else:
                    param_state["momentum_buffer"] = param_state["momentum_buffer"].to(
                        d_p
                    )

                buf = param_state["momentum_buffer"]

                if weight_decay != 0:
                    p_data_fp32.mul_(1 - lr * weight_decay)
                # Nesterov-style update: apply the (rescaled) momentum buffer
                # and the new gradient to the parameters, then fold the new
                # gradient into the buffer for the next step.
                p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct)
                p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr)
                buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)

                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

                group["lr_old"] = lr

        return loss
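
# Minimal usage sketch (not part of the original module): NAG itself depends
# only on torch, so it can be exercised directly as a drop-in
# torch.optim.Optimizer, e.g.:
#
#   import torch
#   model = torch.nn.Linear(4, 1)
#   opt = NAG(model.parameters(), lr=0.1, momentum=0.99, weight_decay=0.0)
#   loss = model(torch.randn(8, 4)).pow(2).mean()
#   opt.zero_grad()
#   loss.backward()
#   opt.step()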