from collections import OrderedDict
import math
import random
from typing import List

import torch
from optimum.quanto import QBytesTensor

from toolkit.optimizers.optimizer_utils import Auto8bitTensor, copy_stochastic, stochastic_grad_accummulation


class Automagic(torch.optim.Optimizer):
    """Adafactor-style optimizer with a per-element, self-tuning learning rate.

    Second-moment statistics follow Adafactor: factored row/column averages for
    tensors with 2+ dims, a full element-wise average otherwise. Instead of one
    global learning rate, every parameter element carries its own lr in
    ``state["lr_mask"]``. Each step that lr is multiplied by ``lr_pump_scale``
    when the element's update polarity (``update > 0``) matches the previous
    step's, by ``lr_dump_scale`` otherwise, and then clamped to
    ``[min_lr, max_lr]``.

    Args:
        params: Iterable of parameters or param-group dicts.
        lr: Initial value for every element of the lr mask. NOTE(review): the
            default ``None`` cannot be used to build the mask in
            ``initialize_state`` -- a number must be supplied.
        min_lr: Lower clamp for the per-element lr.
        max_lr: Upper clamp for the per-element lr.
        lr_pump_scale: Multiplier applied when the update polarity is unchanged.
        lr_dump_scale: Multiplier applied when the update polarity flips.
        eps: Added to the squared gradient; if a tuple/list, only the first
            element is used.
        clip_threshold: The raw (pre-lr) update is rescaled so its RMS is at
            most this value.
        decay_rate: Exponent in the second-moment decay ``1 - step**decay_rate``.
        weight_decay: Weight-decay coefficient, scaled per element by the lr.
        do_paramiter_swapping: If True, only a random subset of parameters is
            trainable at a time (see ``swap_paramiters``).
        paramiter_swapping_factor: Fraction of the total parameter count to
            keep trainable when swapping.
    """

    def __init__(
        self,
        params,
        lr=None,
        min_lr=1e-7,
        max_lr=1e-3,
        lr_pump_scale=1.1,
        lr_dump_scale=0.85,
        eps=(1e-30, 1e-3),
        clip_threshold=1.0,
        decay_rate=-0.8,
        weight_decay=0.0,
        do_paramiter_swapping=False,
        paramiter_swapping_factor=0.1,
    ):
        self.lr = lr
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.lr_pump_scale = lr_pump_scale
        self.lr_dump_scale = lr_dump_scale

        defaults = {
            "lr": lr,
            "eps": eps,
            "clip_threshold": clip_threshold,
            "decay_rate": decay_rate,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        # One base lr per param group; reported by get_learning_rates() as a fallback.
        self.base_lrs: List[float] = [lr for _ in self.param_groups]

        # Trainable non-fp32 params get a post-accumulate-grad hook for
        # stochastic-rounding accumulation. step_hook() later copies each
        # param's `_accum_grad` into `.grad` (presumably written by that hook).
        self.is_stochastic_rounding_accumulation = False
        for group in self.param_groups:
            for param in group["params"]:
                if param.requires_grad and param.dtype != torch.float32:
                    self.is_stochastic_rounding_accumulation = True
                    param.register_post_accumulate_grad_hook(
                        stochastic_grad_accummulation
                    )

        self.do_paramiter_swapping = do_paramiter_swapping
        self.paramiter_swapping_factor = paramiter_swapping_factor

        # Total element count across all groups; swap_paramiters() relies on it.
        self._total_paramiter_size = sum(
            torch.numel(param)
            for group in self.param_groups
            for param in group["params"]
        )
        print(f"Total training paramiters: {self._total_paramiter_size:,}")

        # Must run after the total above has been counted.
        if self.do_paramiter_swapping:
            self.enable_paramiter_swapping(self.paramiter_swapping_factor)

    def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
        """Turn on parameter swapping and immediately pick an initial trainable subset."""
        self.do_paramiter_swapping = True
        self.paramiter_swapping_factor = paramiter_swapping_factor
        self.swap_paramiters()

    def swap_paramiters(self):
        """Freeze every parameter, then unfreeze a random subset.

        All parameters have ``requires_grad`` disabled and their ``.grad``
        cleared. They are then shuffled and re-enabled in order. The loop stops
        at the first parameter whose element count would bring the running total
        to ``paramiter_swapping_factor * total`` or above; that parameter stays
        frozen.
        """
        all_params = []
        for group in self.param_groups:
            for param in group["params"]:
                param.requires_grad_(False)
                param.grad = None
                all_params.append(param)

        random.shuffle(all_params)

        target_paramiters = int(
            self._total_paramiter_size * self.paramiter_swapping_factor)
        total_paramiters = 0
        for param in all_params:
            total_paramiters += torch.numel(param)
            if total_paramiters >= target_paramiters:
                break
            param.requires_grad_(True)

    @staticmethod
    def _get_lr(param_group, param_state):
        """Return the parameter's stored ``avg_lr``, or 0.0 if it has none yet."""
        if "avg_lr" in param_state:
            return param_state["avg_lr"]
        return 0.0

    def _get_group_lr(self, group):
        """Return the mean ``avg_lr`` over the group's parameters that have been stepped.

        Uses ``self.state.get`` rather than ``self.state[p]`` so that no empty
        state entries are created. Such entries would otherwise break
        ``state_dict()``. Parameters without state are excluded instead of
        being counted as 0.0. When no parameter has state yet (e.g. before the
        first step), ``self.lr`` is returned, or 0.0 if it is ``None``.
        """
        group_lrs = []
        for p in group["params"]:
            param_state = self.state.get(p)
            if param_state and "avg_lr" in param_state:
                group_lrs.append(self._get_lr(group, param_state))
        if not group_lrs:
            return self.lr if self.lr is not None else 0.0
        return sum(group_lrs) / len(group_lrs)

    @staticmethod
    def _rms(tensor):
        """Root-mean-square of all elements of ``tensor``."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    @staticmethod
    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        """Rebuild the inverse-sqrt second moment from its factored row/column averages."""
        # copy from fairseq's adafactor implementation:
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    def step_hook(self):
        """Move stochastically-rounded accumulated grads into ``.grad`` before stepping.

        Does nothing unless a stochastic-rounding hook was registered in
        ``__init__``. For every trainable param with an ``_accum_grad``
        attribute, that tensor becomes ``.grad`` and the attribute is deleted.
        """
        if not self.is_stochastic_rounding_accumulation:
            return
        for group in self.param_groups:
            for param in group["params"]:
                if param.requires_grad and hasattr(param, "_accum_grad"):
                    param.grad = param._accum_grad
                    del param._accum_grad

    # adafactor manages its own lr
    def get_learning_rates(self):
        """Return one representative (mean per-element) lr per param group."""
        lrs = [self._get_group_lr(group) for group in self.param_groups]
        if len(lrs) == 0:
            lrs = self.base_lrs
        return lrs

    def get_avg_learning_rate(self):
        """Return the mean of ``get_learning_rates()`` across param groups."""
        lrs = self.get_learning_rates()
        return sum(lrs) / len(lrs)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It runs with autograd enabled and before
                ``step_hook()``, so the gradients it produces are the ones used.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # step() is under no_grad; the closure must be able to backprop.
            with torch.enable_grad():
                loss = closure()

        self.step_hook()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None or not p.requires_grad:
                    continue

                grad = p.grad
                if grad.dtype != torch.float32:
                    grad = grad.to(torch.float32)
                if grad.is_sparse:
                    raise RuntimeError(
                        "Automagic does not support sparse gradients.")

                state = self.state[p]
                factored = len(grad.shape) >= 2

                if len(state) == 0:
                    self.initialize_state(p)
                # Keep the second-moment accumulators in the gradient's dtype
                # (fp32) and device. This includes the first step, where
                # initialize_state() created them in the parameter's dtype.
                if factored:
                    state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
                    state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
                else:
                    state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)

                # Work on an fp32 view/copy of the parameter.
                # NOTE(review): for a QBytesTensor whose dtype is already fp32,
                # the update lands on the dequantized copy and is not written
                # back to `p` -- confirm whether that case can occur.
                p_data_fp32 = p
                if isinstance(p_data_fp32, QBytesTensor):
                    p_data_fp32 = p_data_fp32.dequantize()
                if p.dtype != torch.float32:
                    p_data_fp32 = p_data_fp32.clone().float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)

                update = self._scaled_update(grad, state, group, factored)
                new_lr = self._update_lr_mask(update, state)

                # Apply the per-element learning rate.
                update.mul_(new_lr)

                if group["weight_decay"] != 0:
                    # p -= weight_decay * lr * p, element-wise. `alpha` must be
                    # a Python scalar, so the lr tensor goes into the operand.
                    p_data_fp32.add_(
                        p_data_fp32 * new_lr, alpha=-group["weight_decay"])

                p_data_fp32.add_(-update)

                if p.dtype != torch.float32:
                    # apply stochastic rounding
                    copy_stochastic(p, p_data_fp32)

        return loss

    def _scaled_update(self, grad, state, group, factored):
        """Return the Adafactor-normalized update for ``grad``, before lr scaling.

        Updates the second-moment accumulators in ``state`` in place (reading
        ``state["step"]``, which must already be incremented). The result is
        rescaled so its RMS is at most ``group["clip_threshold"]``.
        """
        beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
        eps = group["eps"]
        if isinstance(eps, (tuple, list)):
            eps = eps[0]

        update = (grad ** 2) + eps
        if factored:
            exp_avg_sq_row = state["exp_avg_sq_row"]
            exp_avg_sq_col = state["exp_avg_sq_col"]
            exp_avg_sq_row.mul_(beta2t).add_(
                update.mean(dim=-1), alpha=(1.0 - beta2t))
            exp_avg_sq_col.mul_(beta2t).add_(
                update.mean(dim=-2), alpha=(1.0 - beta2t))
            # Approximation of exponential moving average of square of gradient
            update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
            update.mul_(grad)
        else:
            exp_avg_sq = state["exp_avg_sq"]
            exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
            update = exp_avg_sq.rsqrt().mul_(grad)

        update.div_(
            (self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
        return update

    def _update_lr_mask(self, update, state):
        """Adapt, store and return the per-element learning rate as an fp32 tensor.

        An element's lr is multiplied by ``lr_pump_scale`` when its polarity
        (``update > 0``) equals the polarity stored from the previous step, and
        by ``lr_dump_scale`` otherwise. The stored polarity starts all False,
        so on the first step non-positive updates count as "unchanged". The
        result is clamped to ``[min_lr, max_lr]`` and saved to
        ``state["lr_mask"]`` (as an Auto8bitTensor); its mean is saved to
        ``state["avg_lr"]``.
        """
        current_polarity = update > 0
        same_polarity = state["last_polarity"] == current_polarity
        state["last_polarity"] = current_polarity

        lr_mask = state["lr_mask"].to(torch.float32)
        new_lr = torch.where(
            same_polarity,
            lr_mask * self.lr_pump_scale,  # Increase lr
            lr_mask * self.lr_dump_scale,  # Decrease lr
        )
        new_lr = torch.clamp(new_lr, min=self.min_lr, max=self.max_lr)

        state["lr_mask"] = Auto8bitTensor(new_lr)
        state["avg_lr"] = torch.mean(new_lr)
        return new_lr

    def initialize_state(self, p):
        """Create optimizer state for parameter ``p``.

        Always resets ``step``, ``RMS`` and the second-moment accumulators, and
        recomputes ``avg_lr`` from the lr mask. ``lr_mask`` (filled with
        ``self.lr``) and ``last_polarity`` (all False) are created only if they
        are missing.
        """
        state = self.state[p]
        state["step"] = 0

        # store the lr mask
        if "lr_mask" not in state:
            state["lr_mask"] = Auto8bitTensor(
                torch.ones(p.shape).to(p.device, dtype=torch.float32) * self.lr
            )
        state["avg_lr"] = torch.mean(state["lr_mask"].to(torch.float32))

        if "last_polarity" not in state:
            state["last_polarity"] = torch.zeros(
                p.shape, dtype=torch.bool, device=p.device)

        factored = len(p.shape) >= 2
        if factored:
            state["exp_avg_sq_row"] = torch.zeros(p.shape[:-1]).to(p)
            state["exp_avg_sq_col"] = torch.zeros(
                p.shape[:-2] + p.shape[-1:]).to(p)
        else:
            state["exp_avg_sq"] = torch.zeros_like(p)

        state["RMS"] = 0

    # override the state_dict to save the lr_mask
    def state_dict(self, *args, **kwargs):
        """Return the optimizer state dict with each ``lr_mask`` serialized.

        Each Auto8bitTensor ``lr_mask`` is replaced by its own ``state_dict()``.
        Per-param state dicts are copied rather than edited, because the base
        implementation hands back the live state dicts. States without an
        ``lr_mask`` are saved as-is.
        """
        orig_state_dict = super().state_dict(*args, **kwargs)
        new_saved_state = {}
        for p, state in orig_state_dict["state"].items():
            save_state = {k: v for k, v in state.items() if k != "lr_mask"}
            if "lr_mask" in state:
                save_state["lr_mask"] = state["lr_mask"].state_dict()
            new_saved_state[p] = save_state
        orig_state_dict["state"] = new_saved_state
        return orig_state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Intentionally a no-op: loading optimizer state is currently disabled.

        TODO: restore loading. NOTE(review): the previous (unreachable) draft
        wrote each restored ``lr_mask`` into ``self.state`` *before* calling
        ``super().load_state_dict()``. The base implementation replaces
        ``self.state``, so those masks would be discarded. The draft also
        looked up saved state by running position, which fails when some
        params have no saved state. The saved mask was rebuilt there as
        ``m['quantized'].to(m['orig_dtype']) * m['scale']``.
        """
        return