import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, OneCycleLR, CosineAnnealingLR, ExponentialLR

import util


class Trainer():
    def __init__(self, model, lrate, wdecay, clip, step_size, seq_out_len, scaler, device, cl=True):
        self.scaler = scaler
        self.model = model
        self.model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lrate, weight_decay=wdecay)
        self.loss = util.masked_mae
        self.clip = clip
        self.step = step_size        # iterations between curriculum-level increases
        self.iter = 1
        self.task_level = 1          # number of output steps currently included in the loss
        self.seq_out_len = seq_out_len
        self.cl = cl                 # whether to use curriculum learning

    def train(self, input, real_val, idx=None):
        self.model.train()
        self.optimizer.zero_grad()
        output = self.model(input, idx=idx)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        predict = self.scaler.inverse_transform(output)

        # Curriculum learning: every `self.step` iterations, extend the loss to one
        # more output step until the full output horizon is covered.
        if self.iter % self.step == 0 and self.task_level <= self.seq_out_len:
            self.task_level += 1
        if self.cl:
            loss = self.loss(predict[:, :, :, :self.task_level], real[:, :, :, :self.task_level], 0.0)
        else:
            loss = self.loss(predict, real, 0.0)

        loss.backward()

        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

        self.optimizer.step()

        mape = util.masked_mape(predict, real, 0.0).item()
        rmse = util.masked_rmse(predict, real, 0.0).item()
        self.iter += 1
        return loss.item(), mape, rmse

    def eval(self, input, real_val):
        self.model.eval()
        with torch.no_grad():
            output = self.model(input)
        output = output.transpose(1, 3)
        real = torch.unsqueeze(real_val, dim=1)
        predict = self.scaler.inverse_transform(output)
        loss = self.loss(predict, real, 0.0)
        mape = util.masked_mape(predict, real, 0.0).item()
        rmse = util.masked_rmse(predict, real, 0.0).item()
        return loss.item(), mape, rmse
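

# Illustrative sketch only (not used elsewhere in this file): one way a training
# script might drive Trainer for a single epoch. The names `model`, `scaler`,
# `dataloader`, the hyperparameter values, and the (batch, seq_len, num_nodes,
# features) layout of the loaded batches are assumptions about the surrounding
# project, not guarantees made by this module.
def _example_trainer_epoch(model, scaler, dataloader, device='cuda'):
    engine = Trainer(model, lrate=1e-3, wdecay=1e-4, clip=5, step_size=100,
                     seq_out_len=12, scaler=scaler, device=device, cl=True)
    losses = []
    for x, y in dataloader:
        # Assumed batch layout (batch, seq_len, num_nodes, features); transpose to
        # the (batch, features, num_nodes, seq_len) layout the model consumes,
        # and take the first target feature as the value to predict.
        trainx = torch.Tensor(x).to(device).transpose(1, 3)
        trainy = torch.Tensor(y).to(device).transpose(1, 3)
        loss, mape, rmse = engine.train(trainx, trainy[:, 0, :, :])
        losses.append(loss)
    return sum(losses) / max(len(losses), 1)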


class Optim(object):

    def _makeOptimizer(self):
        # Note: self.lr_decay is reused here as the optimizer's weight_decay.
        if self.method == 'sgd':
            self.optimizer = optim.SGD(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adagrad':
            self.optimizer = optim.Adagrad(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adadelta':
            self.optimizer = optim.Adadelta(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'adam':
            self.optimizer = optim.Adam(self.params, lr=self.lr, weight_decay=self.lr_decay)
        elif self.method == 'Nadam':
            self.optimizer = optim.NAdam(self.params, lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
        elif self.method == 'adamw':
            self.optimizer = optim.AdamW(self.params, lr=self.lr, weight_decay=self.lr_decay)
        else:
            raise RuntimeError("Invalid optim method: " + self.method)

    def __init__(self, params, method, lr, clip, mode, factor, patience, steps_per_epoch, epochs, lr_decay=1, start_decay_at=None):
        # steps_per_epoch and epochs are accepted for interface compatibility but not used here.
        self.params = params
        self.last_ppl = None
        self.lr = lr
        self.clip = clip
        self.method = method
        self.lr_decay = lr_decay
        self.start_decay_at = start_decay_at
        self.start_decay = False

        self._makeOptimizer()
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode=mode, factor=factor, patience=patience, verbose=True)
    def step(self):
        # Clip gradients (if configured) and apply the update; return the gradient norm.
        grad_norm = 0
        if self.clip is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.clip)

        self.optimizer.step()
        return grad_norm

    def lronplateau(self, loss):
        # Step the ReduceLROnPlateau scheduler with the latest validation metric.
        self.scheduler.step(loss)

    def EXlr(self):
        # Assumes self.scheduler has been replaced with an ExponentialLR instance.
        self.scheduler.step()

    def CosineAnnealingLR(self):
        # Assumes self.scheduler has been replaced with a CosineAnnealingLR instance.
        self.scheduler.step()
    def updateLearningRate(self, ppl, epoch):
        # Decay the learning rate if validation performance does not improve
        # or once the start_decay_at epoch has been reached.
        if self.start_decay_at is not None and epoch >= self.start_decay_at:
            self.start_decay = True
        if self.last_ppl is not None and ppl > self.last_ppl:
            self.start_decay = True

        if self.start_decay:
            self.lr = self.lr * self.lr_decay
            print("Decaying learning rate to %g" % self.lr)
        self.start_decay = False

        self.last_ppl = ppl

        # Rebuild the optimizer so the new learning rate takes effect.
        self._makeOptimizer()
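

# Illustrative sketch only: how the Optim wrapper might be used in place of a bare
# torch optimizer for one training step. `model`, `loss_fn`, `inputs`, `targets`,
# and all hyperparameter values are placeholders/assumptions; steps_per_epoch and
# epochs are passed as None because this class does not use them.
def _example_optim_step(model, loss_fn, inputs, targets):
    # Pass a concrete list so the parameters can be iterated both by the wrapped
    # optimizer and by gradient clipping inside step().
    wrapper = Optim(list(model.parameters()), method='adam', lr=1e-3, clip=5,
                    mode='min', factor=0.5, patience=10,
                    steps_per_epoch=None, epochs=None, lr_decay=1e-4)
    model.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    wrapper.step()                    # clips (if configured) and applies the update
    wrapper.lronplateau(loss.item())  # ReduceLROnPlateau reacts to a metric (typically validation loss)
    return loss.item()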