File size: 567 Bytes
e3641b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch.optim as optim

class LayerDecayOptimizer:
    """Wrap an optimizer and scale each parameter group's learning rate.

    The i-th entry of ``layerwise_decay_rate`` is the multiplier applied to
    the learning rate of the i-th entry of ``optimizer.param_groups`` while
    the wrapped optimizer takes a step.

    The multiplier is applied only for the duration of ``step()``; the
    learning rates stored in ``param_groups`` are restored afterwards. This
    keeps the per-group factor constant across steps (multiplying the stored
    value in place on every call would compound it to ``rate ** n_steps``)
    and leaves the stored values untouched for whatever else reads or
    writes them between steps.
    """

    def __init__(self, optimizer, layerwise_decay_rate):
        """Store the wrapped optimizer and the per-group multipliers.

        Args:
            optimizer: Object exposing ``param_groups`` (a sequence of dicts
                with an ``'lr'`` entry), ``step()`` and ``zero_grad()``.
                NOTE(review): presumably a ``torch.optim.Optimizer`` --
                confirm against the caller.
            layerwise_decay_rate: Sequence indexed by parameter-group
                position; needs at least one entry per parameter group.
        """
        self.optimizer = optimizer
        self.layerwise_decay_rate = layerwise_decay_rate
        # Same list object as the wrapped optimizer's, so edits made through
        # either reference are visible to both.
        self.param_groups = optimizer.param_groups

    def step(self, *args, **kwargs):
        """Take one optimizer step with layer-wise scaled learning rates.

        All arguments are forwarded to the wrapped optimizer's ``step()``.

        Returns:
            Whatever the wrapped optimizer's ``step()`` returns.

        Raises:
            IndexError: If ``layerwise_decay_rate`` has fewer entries than
                there are parameter groups. The stored learning rates are
                restored before the error propagates.
        """
        groups = self.optimizer.param_groups
        stored_lrs = [group['lr'] for group in groups]
        try:
            for i, group in enumerate(groups):
                group['lr'] = stored_lrs[i] * self.layerwise_decay_rate[i]
            return self.optimizer.step(*args, **kwargs)
        finally:
            # Undo the temporary scaling so it never accumulates from one
            # step to the next, even if the step itself raised.
            for group, lr in zip(groups, stored_lrs):
                group['lr'] = lr

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's ``zero_grad()``."""
        self.optimizer.zero_grad(*args, **kwargs)