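"""Trainer utilities for report generation.

BaseTrainer handles device setup, checkpointing, metric monitoring, early stopping,
and best-result bookkeeping; Trainer implements the per-epoch training, validation,
and test loops and dumps the generated reports to JSON files.
"""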
import os
from abc import abstractmethod
import json
import time

import torch
import pandas as pd
from numpy import inf


class BaseTrainer(object):
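    """Base class that handles device setup, metric monitoring, early stopping,
    checkpointing, and best-result bookkeeping. Subclasses implement `_train_epoch`."""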

    def __init__(self, model, criterion, metric_ftns, optimizer, args):
        self.args = args

        # Set up the GPU device(s) if available and move the model onto the configured device.
        self.device, device_ids = self._prepare_device(args.n_gpu)
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        self.epochs = self.args.epochs
        self.save_period = self.args.save_period

        # Metric monitoring: track 'val_<metric>' / 'test_<metric>' and whether lower
        # ('min') or higher ('max') values count as an improvement.
        self.mnt_mode = args.monitor_mode
        self.mnt_metric = 'val_' + args.monitor_metric
        self.mnt_metric_test = 'test_' + args.monitor_metric
        assert self.mnt_mode in ['min', 'max']

        self.mnt_best = inf if self.mnt_mode == 'min' else -inf
        self.early_stop = getattr(self.args, 'early_stop', inf)

        self.start_epoch = 1
        self.checkpoint_dir = args.save_dir
        if not os.path.exists(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        if args.resume is not None:
            self._resume_checkpoint(args.resume)

        self.best_recorder = {'val': {self.mnt_metric: self.mnt_best},
                              'test': {self.mnt_metric_test: self.mnt_best}}

    @abstractmethod
    def _train_epoch(self, epoch):
        raise NotImplementedError

    def train(self):
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # Save logged information into the log dict and track the best results so far.
            log = {'epoch': epoch}
            log.update(result)
            self._record_best(log)

            # Print logged information to the screen.
            for key, value in log.items():
                print('\t{:15s}: {}'.format(str(key), value))

            # Evaluate model performance according to the configured metric and save the
            # best checkpoint as model_best.
            best = False
            if self.mnt_mode != 'off':
                try:
                    # Check whether model performance improved according to mnt_metric.
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    print("Warning: Metric '{}' is not found. "
                          "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    print("Validation performance didn't improve for {} epochs. "
                          "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

        self._print_best()
        self._print_best_to_file()

    def _print_best_to_file(self):
        crt_time = time.asctime(time.localtime(time.time()))
        self.best_recorder['val']['time'] = crt_time
        self.best_recorder['test']['time'] = crt_time
        self.best_recorder['val']['seed'] = self.args.seed
        self.best_recorder['test']['seed'] = self.args.seed
        self.best_recorder['val']['best_model_from'] = 'val'
        self.best_recorder['test']['best_model_from'] = 'test'

        if not os.path.exists(self.args.record_dir):
            os.makedirs(self.args.record_dir)
        record_path = os.path.join(self.args.record_dir, self.args.dataset_name + '.csv')
        if not os.path.exists(record_path):
            record_table = pd.DataFrame()
        else:
            record_table = pd.read_csv(record_path)
        # DataFrame.append was removed in pandas 2.0; build the new rows and concatenate instead.
        new_rows = pd.DataFrame([self.best_recorder['val'], self.best_recorder['test']])
        record_table = pd.concat([record_table, new_rows], ignore_index=True)
        record_table.to_csv(record_path, index=False)

    def _prepare_device(self, n_gpu_use):
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: There is no GPU available on this machine; "
                  "training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print("Warning: The number of GPUs configured to use is {}, but only {} are "
                  "available on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = os.path.join(self.checkpoint_dir, 'current_checkpoint.pth')
        torch.save(state, filename)
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            print("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        resume_path = str(resume_path)
        print("Loading checkpoint: {} ...".format(resume_path))
        # map_location keeps resuming working even when the checkpoint was saved on a
        # different device than the one currently configured.
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))

    def _record_best(self, log):
        improved_val = (self.mnt_mode == 'min' and
                        log[self.mnt_metric] <= self.best_recorder['val'][self.mnt_metric]) or \
                       (self.mnt_mode == 'max' and
                        log[self.mnt_metric] >= self.best_recorder['val'][self.mnt_metric])
        if improved_val:
            self.best_recorder['val'].update(log)

        improved_test = (self.mnt_mode == 'min' and
                         log[self.mnt_metric_test] <= self.best_recorder['test'][self.mnt_metric_test]) or \
                        (self.mnt_mode == 'max' and
                         log[self.mnt_metric_test] >= self.best_recorder['test'][self.mnt_metric_test])
        if improved_test:
            self.best_recorder['test'].update(log)

    def _print_best(self):
        print('Best results (w.r.t {}) in validation set:'.format(self.args.monitor_metric))
        for key, value in self.best_recorder['val'].items():
            print('\t{:15s}: {}'.format(str(key), value))

        print('Best results (w.r.t {}) in test set:'.format(self.args.monitor_metric))
        for key, value in self.best_recorder['test'].items():
            print('\t{:15s}: {}'.format(str(key), value))


class Trainer(BaseTrainer):
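    """Concrete trainer: runs one optimization pass plus validation/test decoding and
    metric computation per epoch, and dumps the generated reports to JSON files."""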

    def __init__(self, model, criterion, metric_ftns, optimizer, args, lr_scheduler,
                 train_dataloader, val_dataloader, test_dataloader):
        super(Trainer, self).__init__(model, criterion, metric_ftns, optimizer, args)
        self.lr_scheduler = lr_scheduler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

    def _train_epoch(self, epoch):
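        """Run one epoch: optimize on the training set, then decode the validation and
        test sets, compute the configured metrics, and dump the generated reports."""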
        train_loss = 0
        self.model.train()
        for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.train_dataloader):
            images = images.to(self.device)
            reports_ids = reports_ids.to(self.device)
            reports_masks = reports_masks.to(self.device)

            output = self.model(images, reports_ids, mode='train')
            loss = self.criterion(output, reports_ids, reports_masks)
            train_loss += loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            # Clip gradients element-wise to stabilize training.
            torch.nn.utils.clip_grad_value_(self.model.parameters(), 0.1)
            self.optimizer.step()
        log = {'train_loss': train_loss / len(self.train_dataloader)}

        self.model.eval()
        with torch.no_grad():
            result_report_val = []
            val_gts, val_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.val_dataloader):
                images = images.to(self.device)
                reports_ids = reports_ids.to(self.device)
                reports_masks = reports_masks.to(self.device)

                output = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                # Keep each image id alongside its generated report for the JSON dump.
                for i in range(reports_ids.shape[0]):
                    result_report_val.append({'reports_ids': images_id[i], 'reports': reports[i]})
                # Decode the ground-truth reports, skipping the first token.
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                val_res.extend(reports)
                val_gts.extend(ground_truths)
            val_met = self.metric_ftns({i: [gt] for i, gt in enumerate(val_gts)},
                                       {i: [re] for i, re in enumerate(val_res)})
            log.update(**{'val_' + k: v for k, v in val_met.items()})

            os.makedirs('valreports', exist_ok=True)
            resFileval = 'valreports/mixed-' + str(epoch) + '.json'
            with open(resFileval, 'w') as f:
                json.dump(result_report_val, f)

        self.model.eval()
        with torch.no_grad():
            result_report_test = []
            test_gts, test_res = [], []
            for batch_idx, (images_id, images, reports_ids, reports_masks) in enumerate(self.test_dataloader):
                images = images.to(self.device)
                reports_ids = reports_ids.to(self.device)
                reports_masks = reports_masks.to(self.device)

                output = self.model(images, mode='sample')
                reports = self.model.tokenizer.decode_batch(output.cpu().numpy())
                for i in range(reports_ids.shape[0]):
                    result_report_test.append({'reports_ids': images_id[i], 'reports': reports[i]})
                ground_truths = self.model.tokenizer.decode_batch(reports_ids[:, 1:].cpu().numpy())
                test_res.extend(reports)
                test_gts.extend(ground_truths)
            test_met = self.metric_ftns({i: [gt] for i, gt in enumerate(test_gts)},
                                        {i: [re] for i, re in enumerate(test_res)})
            log.update(**{'test_' + k: v for k, v in test_met.items()})

            os.makedirs('testreports', exist_ok=True)
            resFiletest = 'testreports/mixed-' + str(epoch) + '.json'
            with open(resFiletest, 'w') as f:
                json.dump(result_report_test, f)

        self.lr_scheduler.step()
        return log
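

# Example wiring (a sketch only; the concrete model, tokenizer, metric functions,
# optimizer, lr scheduler, dataloaders, and the `args` namespace are project-specific
# and assumed here rather than defined in this module):
#
#     trainer = Trainer(model, criterion, metric_ftns, optimizer, args, lr_scheduler,
#                       train_dataloader, val_dataloader, test_dataloader)
#     trainer.train()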