# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
import math
from functools import partial

from pydantic import BaseModel, ConfigDict
from torch import nn
from torch.optim import AdamW, lr_scheduler

logger = logging.getLogger()


class OptimArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")

    # AdamW hyperparameters
    lr: float = 3e-4
    weight_decay: float = 0.1
    epsilon: float = 1e-8
    beta1: float = 0.9
    beta2: float = 0.95

    clip: float = 1.0

    # Learning-rate schedule hyperparameters
    scheduler: str = "cosine"
    warmup: int = 2000
    lr_min_ratio: float = 0.1
    cycle_length: float = 1.0
    cosine_theta: float = 1.0
    annealing_step: int = 1000
    decay_fraction: float = 0.1
    exp_factor: float = 0.5


def lr_linear(step: int, warmup: int, n_steps: int, min_ratio: float) -> float:
    if step < warmup:
        # Linear warmup from 0 to the peak multiplier
        lr = float(step) / warmup
    elif step <= n_steps:
        # Linear decay from the peak down to min_ratio over the remaining steps
        s = float(step - warmup) / (n_steps - warmup)
        lr = s * min_ratio + (1 - s)
    else:
        lr = min_ratio
    return lr


def lr_inv_sqrt(step: int, warmup: int, exp_factor: float, min_ratio: float) -> float:
    if step < warmup:
        lr = float(step) / warmup
    else:
        # Inverse power decay (inverse square root for exp_factor=0.5), floored at min_ratio
        lr = max((warmup**exp_factor) / (step**exp_factor), min_ratio)
    return lr


def lr_cosine(
    step: int,
    warmup: int,
    n_steps: int,
    cycle_length: float,
    theta: float,
    min_ratio: float,
) -> float:
    # Flip the cosine sign every n_steps * cycle_length steps, so successive
    # cycles alternate between the decaying and rising halves of the cosine wave.
    sign = ((step // (n_steps * cycle_length)) % 2) * -2 + 1
    if step < warmup:
        lr = float(step) / warmup
    elif step <= n_steps:
        s = float(step - warmup) / (n_steps - warmup)
        lr = min_ratio + 0.5 * (1 - min_ratio) * (
            sign * math.cos(math.pi * s**theta / cycle_length) + 1
        )
    else:
        lr = min_ratio
    return lr
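

# Illustrative values for lr_cosine (sketch, not part of the original module),
# assuming warmup=100, n_steps=1000, cycle_length=1.0, theta=1.0, min_ratio=0.1:
#   step=100  -> 1.00  (end of warmup, peak multiplier)
#   step=550  -> 0.55  (halfway through the cosine decay)
#   step=1500 -> 0.10  (past n_steps, held at min_ratio)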


def lr_wsd(
    step: int,
    warmup: int,
    n_steps: int,
    decay_fraction: float,
    cycle_length: float,
    min_ratio: float,
) -> float:
    """
    Warmup-stable-decay (WSD) schedule, following
    "Understanding Warmup-Stable-Decay Learning Rates: A River Valley Loss Landscape Perspective"
    (https://arxiv.org/pdf/2410.05192).
    """
    cycle_num = step // int(n_steps * cycle_length) + 1
    curr_n_steps = int(n_steps * cycle_length) * cycle_num
    decay_length = int(curr_n_steps * decay_fraction)
    if step < warmup:
        # Linear warmup
        lr = float(step) / warmup
    elif step <= curr_n_steps - decay_length:
        # Stable phase: hold the peak multiplier
        lr = 1.0
    elif step > curr_n_steps - decay_length and step <= curr_n_steps:
        # Decay phase: 1/x-style decay down to min_ratio.
        # Linear interpolation gives similar results:
        # slope = -(1.0 - min_ratio) / decay_length
        # intercept = min_ratio + ((1.0 - min_ratio) * curr_n_steps) / decay_length
        # lr = slope * step + intercept
        step = step - (curr_n_steps - decay_length)
        lr = 1 / ((step / curr_n_steps) * (1 / min_ratio) + (1 - step / curr_n_steps))
    else:
        lr = min_ratio
    return lr
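

# Illustrative values for lr_wsd (sketch, not part of the original module),
# assuming warmup=100, n_steps=1000, cycle_length=1.0, decay_fraction=0.1,
# min_ratio=0.1 (so decay_length=100 in the first cycle):
#   step=50   -> 0.50  (warmup)
#   step=500  -> 1.00  (stable phase)
#   step=950  -> ~0.69 (1/x decay toward min_ratio)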


def build_lr_fn(args: OptimArgs, n_steps: int):
    if args.scheduler == "constant":
        lr_fn = lambda x: 1.0
    elif args.scheduler == "linear":
        lr_fn = partial(
            lr_linear, warmup=args.warmup, n_steps=n_steps, min_ratio=args.lr_min_ratio
        )
    elif args.scheduler == "inv_sqrt":
        lr_fn = partial(
            lr_inv_sqrt,
            warmup=args.warmup,
            exp_factor=args.exp_factor,
            min_ratio=args.lr_min_ratio,
        )
    elif args.scheduler == "cosine":
        lr_fn = partial(
            lr_cosine,
            warmup=args.warmup,
            n_steps=n_steps,
            cycle_length=args.cycle_length,
            theta=args.cosine_theta,
            min_ratio=args.lr_min_ratio,
        )
    elif args.scheduler == "wsd":
        assert args.decay_fraction < args.cycle_length
        lr_fn = partial(
            lr_wsd,
            warmup=args.warmup,
            n_steps=n_steps,
            decay_fraction=args.decay_fraction,
            cycle_length=args.cycle_length,
            min_ratio=args.lr_min_ratio,
        )
    else:
        raise NotImplementedError(f"Unknown scheduler: {args.scheduler}")
    return lr_fn


def build_optimizer(model: nn.Module, args: OptimArgs, n_steps: int):
    logger.info("Starting build of optimizer...")
    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        weight_decay=args.weight_decay,
        eps=args.epsilon,
        fused=True,  # Faster optim.step but can throw errors
    )

    # scheduler
    lr_fn = build_lr_fn(args, n_steps)
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_fn)

    logger.info("Done with build of optimizer.")
    return optimizer, scheduler
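

# Example usage (illustrative sketch, not part of the original module; the
# dataloader and training loop below are hypothetical):
#
#   model = nn.Linear(128, 128)            # stand-in for the real model
#   args = OptimArgs(lr=3e-4, scheduler="cosine", warmup=2000)
#   optimizer, scheduler = build_optimizer(model, args, n_steps=100_000)
#   for batch in dataloader:
#       loss = model(batch).sum()
#       loss.backward()
#       torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
#       optimizer.step()
#       scheduler.step()                   # LambdaLR scales args.lr by lr_fn(step)
#       optimizer.zero_grad()
#
# Note: fused=True in AdamW may require parameters on CUDA, depending on the
# PyTorch version.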