# PromptNet/modules/trainer.py
import os
from abc import ABC, abstractmethod
import json
import time
import torch
import pandas as pd
from numpy import inf
class BaseTrainer(ABC):
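    """Device setup, checkpointing, metric monitoring and early stopping shared by all trainers.

    Subclasses implement `_train_epoch` and return a dict of logged metrics for each epoch.
    """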
def __init__(self, model, criterion, metric_ftns, optimizer, args):
self.args = args
# setup GPU device if available, move model into configured device
self.device, device_ids = self._prepare_device(args.n_gpu)
self.model = model.to(self.device)
if len(device_ids) > 1:
self.model = torch.nn.DataParallel(model, device_ids=device_ids)
self.criterion = criterion
self.metric_ftns = metric_ftns
self.optimizer = optimizer
self.epochs = self.args.epochs
self.save_period = self.args.save_period
self.mnt_mode = args.monitor_mode
self.mnt_metric = 'val_' + args.monitor_metric
self.mnt_metric_test = 'test_' + args.monitor_metric
assert self.mnt_mode in ['min', 'max']
self.mnt_best = inf if self.mnt_mode == 'min' else -inf
self.early_stop = getattr(self.args, 'early_stop', inf)
self.start_epoch = 1
self.checkpoint_dir = args.save_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)
if args.resume is not None:
self._resume_checkpoint(args.resume)
self.best_recorder = {'val': {self.mnt_metric: self.mnt_best},
'test': {self.mnt_metric_test: self.mnt_best}}
@abstractmethod
def _train_epoch(self, epoch):
raise NotImplementedError
def train(self):
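        """Run the full training loop: train each epoch, log metrics, track the best
        monitored metric, stop early when it stops improving, and save checkpoints."""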
not_improved_count = 0
for epoch in range(self.start_epoch, self.epochs + 1):
result = self._train_epoch(epoch)
# save logged informations into log dict
log = {'epoch': epoch}
log.update(result)
self._record_best(log)
# print logged informations to the screen
for key, value in log.items():
print('\t{:15s}: {}'.format(str(key), value))
# evaluate model performance according to configured metric, save best checkpoint as model_best
best = False
if self.mnt_mode != 'off':
try:
# check whether model performance improved or not, according to specified metric(mnt_metric)
improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
(self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
except KeyError:
print("Warning: Metric '{}' is not found. " "Model performance monitoring is disabled.".format(
self.mnt_metric))
self.mnt_mode = 'off'
improved = False
if improved:
self.mnt_best = log[self.mnt_metric]
not_improved_count = 0
best = True
else:
not_improved_count += 1
if not_improved_count > self.early_stop:
print("Validation performance didn\'t improve for {} epochs. " "Training stops.".format(
self.early_stop))
break
if epoch % self.save_period == 0:
self._save_checkpoint(epoch, save_best=best)
self._print_best()
self._print_best_to_file()
def _print_best_to_file(self):
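        """Append the best validation and test records (plus time and seed) to a per-dataset CSV."""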
crt_time = time.asctime(time.localtime(time.time()))
self.best_recorder['val']['time'] = crt_time
self.best_recorder['test']['time'] = crt_time
self.best_recorder['val']['seed'] = self.args.seed
self.best_recorder['test']['seed'] = self.args.seed
self.best_recorder['val']['best_model_from'] = 'val'
self.best_recorder['test']['best_model_from'] = 'test'
        os.makedirs(self.args.record_dir, exist_ok=True)
        record_path = os.path.join(self.args.record_dir, self.args.dataset_name + '.csv')
        if not os.path.exists(record_path):
            record_table = pd.DataFrame()
        else:
            record_table = pd.read_csv(record_path)
        # DataFrame.append was removed in pandas 2.0; concatenate the two records instead
        new_rows = pd.DataFrame([self.best_recorder['val'], self.best_recorder['test']])
        record_table = pd.concat([record_table, new_rows], ignore_index=True)
        record_table.to_csv(record_path, index=False)
def _prepare_device(self, n_gpu_use):
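        """Choose the torch device and the list of GPU ids to use, warning when fewer GPUs are available than requested."""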
n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: There's no GPU available on this machine, "
                  "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print("Warning: The number of GPUs configured to use is {}, but only {} are available "
                  "on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
list_ids = list(range(n_gpu_use))
return device, list_ids
def _save_checkpoint(self, epoch, save_best=False):
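        """Save the current epoch, model/optimizer state and best monitored value; also copy to model_best.pth when save_best is set."""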
state = {
'epoch': epoch,
'state_dict': self.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'monitor_best': self.mnt_best
}
filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth')
torch.save(state, filename)
print("Saving checkpoint: {} ...".format(filename))
if save_best:
best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
torch.save(state, best_path)
print("Saving current best: model_best.pth ...")
def _resume_checkpoint(self, resume_path):
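        """Restore model/optimizer state and monitoring bookkeeping from a saved checkpoint."""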
resume_path = str(resume_path)
print("Loading checkpoint: {} ...".format(resume_path))
        # map_location lets a checkpoint saved on GPU be restored on the current device
        checkpoint = torch.load(resume_path, map_location=self.device)
self.start_epoch = checkpoint['epoch'] + 1
self.mnt_best = checkpoint['monitor_best']
self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
def _record_best(self, log):
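        """Update the stored best validation/test records whenever the monitored metric improves."""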
        improved_val = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.best_recorder['val'][self.mnt_metric]) or \
                       (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric])
        if improved_val:
            self.best_recorder['val'].update(log)
        improved_test = (self.mnt_mode == 'min' and log[self.mnt_metric_test] <= self.best_recorder['test'][self.mnt_metric_test]) or \
                        (self.mnt_mode == 'max' and log[self.mnt_metric_test] >= self.best_recorder['test'][self.mnt_metric_test])
        if improved_test:
            self.best_recorder['test'].update(log)
def _print_best(self):
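        """Print the best validation and test results recorded so far."""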
print('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric))
for key, value in self.best_recorder['val'].items():
print('\t{:15s}: {}'.format(str(key), value))
print('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric))
for key, value in self.best_recorder['test'].items():
print('\t{:15s}: {}'.format(str(key), value))
        os.makedirs('valreports/', exist_ok=True)
        os.makedirs('testreports/', exist_ok=True)
class Trainer(BaseTrainer):
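    """Concrete trainer: runs one supervised epoch over the training set, then decodes
    and scores generated reports on the validation and test sets."""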
def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler, train_dataloader, val_dataloader,
test_dataloader):
super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args)
self.lr_scheduler = lr_scheduler
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
self.test_dataloader = test_dataloader
def _train_epoch(self, epoch):
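        """Train for one epoch, then generate reports on the validation and test sets,
        compute metrics, and dump the generated reports to JSON files."""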
train_loss = 0
self.model.train()
for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader):
images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(self.device), reports_masks.to(
self.device)
output = self.model(images, reports_ids, mode='train')
loss = self.criterion(output, reports_ids, reports_masks)
train_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1)
self.optimizer.step()
log = {'train_loss': train_loss / len(self.train_dataloader)}
self.model.eval()
with torch.no_grad():
result_report_val = []
val_gts, val_res = [], []
for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader):
images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
self.device), reports_masks.to(self.device)
output = self.model(images, mode='sample')
reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
for i in range(reports_ids.shape[0]):
temp1 = {'reports_ids': images_id[i], 'reports': reports[i]}
result_report_val.append(temp1)
ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
val_res.extend(reports)
val_gts.extend(ground_truths)
val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)},
{i: [re] for i, re in enumerate(val_res)})
log.update(**{'val_' + k: v for k, v in val_met.items()})
            # ensure the report directory exists before writing, and close the file properly
            os.makedirs('valreports', exist_ok=True)
            resFileval = 'valreports/mixed-' + str(epoch) + '.json'
            with open(resFileval, 'w') as f:
                json.dump(result_report_val, f)
self.model.eval()
with torch.no_grad():
result_report_test = []
test_gts, test_res = [], []
for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader):
images, reports_ids, reports_masks = images.to(self.device), reports_ids.to(
self.device), reports_masks.to(self.device)
output = self.model(images, mode='sample')
reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
for i in range(reports_ids.shape[0]):
temp = {'reports_ids': images_id[i], 'reports': reports[i]}
result_report_test.append(temp)
ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
test_res.extend(reports)
test_gts.extend(ground_truths)
test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
{i: [re] for i, re in enumerate(test_res)})
log.update(**{'test_' + k: v for k, v in test_met.items()})
            # ensure the report directory exists before writing, and close the file properly
            os.makedirs('testreports', exist_ok=True)
            resFiletest = 'testreports/mixed-' + str(epoch) + '.json'
            with open(resFiletest, 'w') as f:
                json.dump(result_report_test, f)
self.lr_scheduler.step()
return log
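# Hypothetical usage sketch (not part of the original file). The attribute names on `args`
# mirror how it is accessed above; `build_model`, `compute_loss`, `compute_metrics` and the
# dataloaders are placeholders for the project's own builders.
#
#   args = parse_args()                       # expects n_gpu, epochs, save_period, monitor_mode,
#                                             # monitor_metric, save_dir, record_dir, dataset_name,
#                                             # seed, resume, early_stop
#   model = build_model(args)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
#   trainer = Trainer(model, compute_loss, compute_metrics, optimizer, args,
#                     lr_scheduler, train_dataloader, val_dataloader, test_dataloader)
#   trainer.train()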