# DMP-PCFC / trainer.py
# XingyuLiang's picture
# Upload 75 files
# cd1df48 verified
import torch.optim as optim
import util
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR, ExponentialLR
class Trainer():
    """Run single training / evaluation steps for a forecasting model.

    Owns the Adam optimizer, the loss function (``util.masked_mae``) and the
    curriculum-learning state (``task_level``), which controls how much of
    the last axis of the prediction contributes to the training loss.
    """

    def __init__(self, model, lrate, wdecay, clip, step_size, seq_out_len, scaler, device, cl=True):
        """Move ``model`` to ``device`` and set up the optimizer and state.

        Args:
            model: module called as ``model(input, idx=idx)`` in ``train``
                and ``model(input)`` in ``eval``.
            lrate: learning rate for Adam.
            wdecay: weight decay for Adam.
            clip: max gradient norm, or ``None`` to disable clipping.
            step_size: number of training iterations between curriculum
                level increases.
            seq_out_len: upper bound for the curriculum level.
            scaler: object exposing ``inverse_transform`` applied to the
                model output before the loss is computed.
            device: device the model is moved to.
            cl: enable curriculum learning when true.
        """
        self.scaler = scaler
        self.model = model
        self.model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay)
        self.loss = util.masked_mae
        self.clip = clip
        self.step = step_size
        self.iter = 1
        self.task_level = 1
        self.seq_out_len = seq_out_len
        self.cl = cl

    def train(self, input, real_val, idx=None):
        """Run one optimization step.

        Args:
            input: batch passed to the model.
            real_val: targets; a singleton axis is inserted at dim 1 before
                comparison with the prediction.
            idx: optional value forwarded to the model as ``idx``.

        Returns:
            Tuple ``(loss, mape, rmse)`` of Python floats. ``loss`` is the
            (possibly curriculum-truncated) training loss; ``mape`` and
            ``rmse`` are always computed on the full prediction.
        """
        self.model.train()
        self.optimizer.zero_grad()

        output = self.model(input, idx=idx)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        predict = self.scaler.inverse_transform(output)

        # Raise the curriculum level every `step` iterations, capped at
        # seq_out_len (the previous `<=` let it overshoot by one).
        if self.iter % self.step == 0 and self.task_level < self.seq_out_len:
            self.task_level += 1

        if self.cl:
            # Only the first `task_level` entries of the last axis are
            # penalized (presumably the forecast horizon -- TODO confirm).
            loss = self.loss(
                predict[:, :, :, :self.task_level],
                real[:, :, :, :self.task_level],
                0.0,
            )
        else:
            loss = self.loss(predict, real, 0.0)

        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()

        # Reporting metrics never need gradients.
        with torch.no_grad():
            mape = util.masked_mape(predict, real, 0.0).item()
            rmse = util.masked_rmse(predict, real, 0.0).item()

        self.iter += 1
        return loss.item(), mape, rmse

    def eval(self, input, real_val):
        """Evaluate one batch without updating the model.

        Args:
            input: batch passed to the model.
            real_val: targets; a singleton axis is inserted at dim 1 before
                comparison with the prediction.

        Returns:
            Tuple ``(loss, mape, rmse)`` of Python floats computed on the
            full prediction.
        """
        self.model.eval()
        # No backward pass happens here, so skip building the autograd graph.
        with torch.no_grad():
            output = self.model(input)
            output = output.transpose(1, 3)
            real = torch.unsqueeze(real_val, dim=1)
            predict = self.scaler.inverse_transform(output)
            loss = self.loss(predict, real, 0.0)
            mape = util.masked_mape(predict, real, 0.0).item()
            rmse = util.masked_rmse(predict, real, 0.0).item()
        return loss.item(), mape, rmse
class Optim(object):
    """Bundle an optimizer, optional gradient clipping and an LR scheduler.

    The optimizer is selected by name (``method``). A ``ReduceLROnPlateau``
    scheduler is attached to it, and ``updateLearningRate`` offers a manual
    multiplicative decay as an alternative.
    """

    # Method name -> ``torch.optim`` class name, for the optimizers that are
    # all built with the same ``lr`` / ``weight_decay`` arguments.
    _STANDARD_OPTIMIZERS = {
        'sgd': 'SGD',
        'adagrad': 'Adagrad',
        'adadelta': 'Adadelta',
        'adam': 'Adam',
        'adamw': 'AdamW',
    }

    def _makeOptimizer(self):
        """Build ``self.optimizer`` for ``self.method`` using ``self.lr``.

        Raises:
            RuntimeError: if ``self.method`` is not a supported name.
        """
        # NOTE(review): ``lr_decay`` is passed as the optimizer's
        # ``weight_decay`` AND used as the multiplicative LR factor in
        # ``updateLearningRate``. Kept as-is because callers may rely on it;
        # confirm which meaning is intended.
        if self.method == 'Nadam':
            self.optimizer = optim.NAdam(self.params, lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        elif self.method in self._STANDARD_OPTIMIZERS:
            factory = getattr(optim, self._STANDARD_OPTIMIZERS[self.method])
            self.optimizer = factory(self.params, lr=self.lr, weight_decay=self.lr_decay)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, params, method, lr, clip, mode, factor, patience, steps_per_epoch, epochs, lr_decay=1, start_decay_at=None):
        """Create the optimizer and its plateau scheduler.

        Args:
            params: iterable of parameters to optimize (may be a generator).
            method: one of ``'sgd'``, ``'adagrad'``, ``'adadelta'``,
                ``'adam'``, ``'Nadam'``, ``'adamw'``.
            lr: initial learning rate.
            clip: max gradient norm used in ``step``, or ``None`` to disable.
            mode, factor, patience: forwarded to ``ReduceLROnPlateau``.
            steps_per_epoch, epochs: currently unused; kept so existing
                call sites keep working.
            lr_decay: factor applied by ``updateLearningRate``; also passed
                to the optimizer as ``weight_decay`` (see ``_makeOptimizer``).
            start_decay_at: epoch from which ``updateLearningRate`` decays
                on every call, or ``None``.

        Raises:
            RuntimeError: if ``method`` is not supported.
        """
        # Materialize the parameters: a generator such as
        # ``model.parameters()`` would be exhausted by the optimizer
        # constructor, leaving ``step`` with nothing to clip.
        self.params = list(params)
        self.last_ppl = None
        self.lr = lr
        self.clip = clip
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False
        self._makeOptimizer()

        scheduler_kwargs = dict(mode=mode, factor=factor, patience=patience)
        try:
            self.scheduler = ReduceLROnPlateau(self.optimizer, verbose=True, **scheduler_kwargs)
        except TypeError:
            # ``verbose`` is deprecated in recent PyTorch releases; fall back
            # to the plain constructor where it is no longer accepted.
            self.scheduler = ReduceLROnPlateau(self.optimizer, **scheduler_kwargs)

    def step(self):
        """Clip gradients (when ``clip`` is set) and apply one optimizer step.

        Returns:
            Always ``0``; the gradient norm is not reported.
        """
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.params, self.clip)
        self.optimizer.step()
        return 0

    def lronplateau(self, loss):
        """Step the scheduler with the monitored metric ``loss``."""
        self.scheduler.step(loss)

    def EXlr(self):
        """Step the scheduler without a metric.

        Intended for schedulers that are stepped without arguments; the
        ``ReduceLROnPlateau`` created in ``__init__`` expects a metric, so
        use ``lronplateau`` with it.
        """
        self.scheduler.step()

    def CosineAnnealingLR(self):
        """Step the scheduler without a metric (same caveat as ``EXlr``)."""
        self.scheduler.step()

    def updateLearningRate(self, ppl, epoch):
        """Decay the learning rate when validation performance worsens.

        The rate is multiplied by ``lr_decay`` if ``epoch`` has reached
        ``start_decay_at`` or if ``ppl`` is higher than on the previous call.

        Args:
            ppl: validation metric for this epoch (lower is better).
            epoch: current epoch number.
        """
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)
            # Update the existing optimizer in place rather than rebuilding
            # it: a rebuild discards optimizer state and leaves
            # ``self.scheduler`` attached to the old instance.
            for group in self.optimizer.param_groups:
                group['lr'] = self.lr

        # only decay for one epoch
        self.start_decay = False
        self.last_ppl = ppl