r""" Hypercorrelation Squeeze training (validation) code """
import argparse

import torch
import torch.nn as nn
import torch.optim as optim

from fewshot_data.common import utils
from fewshot_data.common.evaluation import Evaluator
from fewshot_data.common.logger import Logger, AverageMeter
from fewshot_data.data.dataset import FSSDataset
from fewshot_data.model.hsnet import HypercorrSqueezeNetwork


def train(epoch, model, dataloader, optimizer, training):
r""" Train HSNet """
# Force randomness during training / freeze randomness during testing
utils.fix_randseed(None) if training else utils.fix_randseed(0)
model.module.train_mode() if training else model.module.eval()
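    # NOTE: train_mode() is HSNet's custom mode switch; unlike plain train(),
    # it keeps the frozen backbone in eval() so that its BatchNorm statistics
    # are not updated while the squeeze/decoder layers train.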
    average_meter = AverageMeter(dataloader.dataset)

    for idx, batch in enumerate(dataloader):

        # 1. Hypercorrelation Squeeze Networks forward pass
        batch = utils.to_cuda(batch)
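        # squeeze(1) drops the singleton shot dimension: this script runs
        # 1-shot episodes, so support tensors are expected as [bsz, 1, ...]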
        logit_mask = model(batch['query_img'], batch['support_imgs'].squeeze(1), batch['support_masks'].squeeze(1))
        pred_mask = logit_mask.argmax(dim=1)

        # 2. Compute loss & update model parameters
        loss = model.module.compute_objective(logit_mask, batch['query_mask'])
        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
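        # (during validation the caller wraps this function in torch.no_grad(),
        #  so no gradients are accumulated and the optimizer is never stepped)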
        # 3. Evaluate prediction
        area_inter, area_union = Evaluator.classify_prediction(pred_mask, batch)
        average_meter.update(area_inter, area_union, batch['class_id'], loss.detach().clone())
        average_meter.write_process(idx, len(dataloader), epoch, write_batch_idx=50)

    # Write evaluation results
    average_meter.write_result('Training' if training else 'Validation', epoch)
    avg_loss = utils.mean(average_meter.loss_buf)
    miou, fb_iou = average_meter.compute_iou()
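    # miou is the mean IoU over the classes of the current fold; fb_iou is the
    # foreground/background IoU that few-shot segmentation works also report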
    return avg_loss, miou, fb_iou


if __name__ == '__main__':

    # Arguments parsing
    parser = argparse.ArgumentParser(description='Hypercorrelation Squeeze PyTorch Implementation')
    parser.add_argument('--datapath', type=str, default='fewshot_data/Datasets_HSN')
    parser.add_argument('--benchmark', type=str, default='pascal', choices=['pascal', 'coco', 'fss'])
    parser.add_argument('--logpath', type=str, default='')
    parser.add_argument('--bsz', type=int, default=20)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--niter', type=int, default=2000)
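    # (despite its name, --niter is consumed as a number of epochs by the loop below)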
    parser.add_argument('--nworker', type=int, default=8)
    parser.add_argument('--fold', type=int, default=0, choices=[0, 1, 2, 3])
    parser.add_argument('--backbone', type=str, default='resnet101', choices=['vgg16', 'resnet50', 'resnet101'])
    args = parser.parse_args()
    Logger.initialize(args, training=True)

    # Model initialization
    model = HypercorrSqueezeNetwork(args.backbone, False)
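    # (the second constructor argument is use_original_imgsize; False keeps
    #  predictions at the resized training resolution)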
    Logger.log_params(model)

    # Device setup
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())
    model = nn.DataParallel(model)
    model.to(device)

    # Helper classes (for training) initialization
    optimizer = optim.Adam([{"params": model.parameters(), "lr": args.lr}])
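    # (Adam receives every parameter, but HSNet extracts backbone features
    #  under torch.no_grad(), so in practice only the squeeze and decoder
    #  layers are updated)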
    Evaluator.initialize()

    # Dataset initialization
    FSSDataset.initialize(img_size=400, datapath=args.datapath, use_original_imgsize=False)
    dataloader_trn = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'trn')
    dataloader_val = FSSDataset.build_dataloader(args.benchmark, args.bsz, args.nworker, args.fold, 'val')
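    # (--fold picks the cross-validation split: the fold's classes are held
    #  out for validation while the remaining classes are used for training)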
    # Train HSNet
    best_val_miou = float('-inf')
    for epoch in range(args.niter):

        trn_loss, trn_miou, trn_fb_iou = train(epoch, model, dataloader_trn, optimizer, training=True)
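        # Validation reuses train() with training=False; torch.no_grad()
        # disables autograd bookkeeping for the entire pass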
        with torch.no_grad():
            val_loss, val_miou, val_fb_iou = train(epoch, model, dataloader_val, optimizer, training=False)

        # Save the best model
        if val_miou > best_val_miou:
            best_val_miou = val_miou
            Logger.save_model_miou(model, epoch, val_miou)

        Logger.tbd_writer.add_scalars('fewshot_data/data/loss', {'trn_loss': trn_loss, 'val_loss': val_loss}, epoch)
        Logger.tbd_writer.add_scalars('fewshot_data/data/miou', {'trn_miou': trn_miou, 'val_miou': val_miou}, epoch)
        Logger.tbd_writer.add_scalars('fewshot_data/data/fb_iou', {'trn_fb_iou': trn_fb_iou, 'val_fb_iou': val_fb_iou}, epoch)
        Logger.tbd_writer.flush()

    Logger.tbd_writer.close()
    Logger.info('==================== Finished Training ====================')