Spaces:
Sleeping
Sleeping
File size: 567 Bytes
e3641b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch.optim as optim
class LayerDecayOptimizer:
    """Wrap an optimizer and apply a per-param-group learning-rate multiplier.

    On every ``step`` the learning rate of param group ``i`` is multiplied by
    ``layerwise_decay_rate[i]`` only for the duration of that update and is
    then restored. ``group['lr']`` therefore keeps holding the base learning
    rate: the multiplier does not compound from one step to the next, and
    anything else that edits ``group['lr']`` (e.g. an lr scheduler) keeps
    operating on the unscaled value.

    Args:
        optimizer: The optimizer to wrap. Presumably a
            ``torch.optim.Optimizer``; it must expose ``param_groups``,
            ``step`` and ``zero_grad``.
        layerwise_decay_rate: Indexable sequence of multipliers, matched to
            param groups by position in ``optimizer.param_groups``.
    """

    def __init__(self, optimizer, layerwise_decay_rate):
        self.optimizer = optimizer
        self.layerwise_decay_rate = layerwise_decay_rate
        # Same list object as the wrapped optimizer's, so group dicts are shared.
        self.param_groups = optimizer.param_groups

    def step(self, *args, **kwargs):
        """Run one update of the wrapped optimizer with scaled learning rates.

        All arguments (e.g. a ``closure``) are forwarded unchanged to the
        wrapped optimizer's ``step``, and its return value is returned.

        Raises:
            IndexError: if ``layerwise_decay_rate`` is a list/tuple with fewer
                entries than there are param groups.
        """
        groups = self.optimizer.param_groups
        # Look up every multiplier before mutating any lr, so a too-short
        # rate sequence fails without leaving some groups scaled.
        rates = [self.layerwise_decay_rate[i] for i in range(len(groups))]
        base_lrs = [group['lr'] for group in groups]
        for group, base_lr, rate in zip(groups, base_lrs, rates):
            group['lr'] = base_lr * rate
        try:
            return self.optimizer.step(*args, **kwargs)
        finally:
            # Put the base lrs back so the next step scales from the base
            # value again rather than from an already-scaled one.
            for group, base_lr in zip(groups, base_lrs):
                group['lr'] = base_lr

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's ``zero_grad``."""
        self.optimizer.zero_grad(*args, **kwargs)