Spaces:
Build error
Build error
r""" Logging """ | |
import datetime | |
import logging | |
import os | |
from tensorboardX import SummaryWriter | |
import torch | |
class Logger:
    r""" Writes results of training/testing

    All state lives on the class itself (``logpath``, ``benchmark``,
    ``tbd_writer``), and every method is a classmethod so callers can use
    ``Logger.info(...)`` etc. without creating an instance.
    """

    @classmethod
    def initialize(cls, args, training):
        r""" Create the log directory and set up file, console and TensorBoard logging.

        Args:
            args: parsed argument namespace. Reads ``logpath`` when training,
                ``load`` when testing, and always ``benchmark``. When
                ``training`` is true, every attribute is echoed to the log.
            training: ``True`` for a training run, ``False`` for a test run.

        Raises:
            FileExistsError: if the log directory already exists
                (``os.makedirs`` is called without ``exist_ok``).
        """
        logtime = datetime.datetime.now().strftime('_%m%d_%H%M%S')
        # Test runs are named after the loaded checkpoint's file stem plus a timestamp.
        logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-1].split('.')[0] + logtime
        if logpath == '':
            logpath = logtime

        # NOTE: despite the '.log' suffix this is a directory; it holds log.txt,
        # the TensorBoard runs and saved checkpoints.
        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config: mirror every INFO message to the terminal.
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        if training:
            logging.info(':======== Convolutional Hough Matching Networks =========')
            for arg_key, arg_val in vars(args).items():
                logging.info('| %20s: %-24s' % (arg_key, str(arg_val)))
            logging.info(':========================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes message to .txt

        Emits ``msg`` at INFO level on the root logger. After ``initialize``
        has run, this goes to ``log.txt`` and to the console.
        """
        logging.info(msg)

    @classmethod
    def save_model(cls, model, epoch, val_pck):
        r""" Save the model weights as the best checkpoint and log it.

        Args:
            model: module whose ``state_dict()`` is saved. Any previous
                ``pck_best_model.pt`` in ``cls.logpath`` is overwritten.
            epoch: epoch number, included in the log message.
            val_pck: validation PCK, included in the log message.
        """
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'pck_best_model.pt'))
        cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))
class AverageMeter:
    r""" Stores loss, evaluation results, selected layers

    Accumulates per-sample evaluation results (one list per key in
    ``buffer_keys``) and per-batch losses, and writes running or final
    averages through ``Logger.info``.
    """

    def __init__(self, benchamrk):
        r""" Constructor of AverageMeter

        Args:
            benchamrk: currently unused. The misspelled name is kept so that
                existing keyword callers keep working.
        """
        self.buffer_keys = ['pck']
        self.buffer = {key: [] for key in self.buffer_keys}
        self.loss_buffer = []

    def update(self, eval_result, loss=None):
        r""" Append one batch of evaluation results and, optionally, its loss.

        Args:
            eval_result: mapping with an entry for every key in
                ``buffer_keys``. Each entry is an iterable whose items are
                appended individually to the matching buffer (list ``+=``).
            loss: optional loss value for this batch. Skipped when ``None``.
        """
        for key in self.buffer_keys:
            self.buffer[key] += eval_result[key]
        if loss is not None:
            self.loss_buffer.append(loss)

    def write_result(self, split, epoch):
        r""" Log an end-of-epoch summary for ``split``.

        Includes the mean loss (when any loss was recorded) and the mean of
        each non-empty buffer. Empty buffers are skipped, just as the loss is,
        instead of raising ``ZeroDivisionError``.
        """
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch
        if self.loss_buffer:
            msg += 'Loss: %5.2f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
        for key in self.buffer_keys:
            values = self.buffer[key]
            if not values:
                continue
            # NOTE(review): unlike write_process/write_test_process there is no
            # "* 100" here; confirm which unit the buffered values are in.
            msg += '%s: %6.2f ' % (key.upper(), sum(values) / len(values))
        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch):
        r""" Log running averages after batch ``batch_idx`` (0-based) of ``datalen``.

        Reports the latest and the mean loss (when any loss was recorded) and
        ``100 *`` the mean of each non-empty buffer.
        """
        msg = '[Epoch: %02d] ' % epoch
        msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
        if self.loss_buffer:
            msg += 'Loss: %5.2f ' % self.loss_buffer[-1]
            msg += 'Avg Loss: %5.5f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
        for key in self.buffer_keys:
            values = self.buffer[key]
            if not values:
                continue
            msg += 'Avg %s: %5.2f ' % (key.upper(), sum(values) / len(values) * 100)
        Logger.info(msg)

    def write_test_process(self, batch_idx, datalen):
        r""" Log running test averages after batch ``batch_idx`` (0-based) of ``datalen``.

        For ``'pck'`` the buffered tensors are stacked and averaged over the
        first dimension. Each resulting entry is printed as a percentage.
        Empty buffers are skipped because ``torch.stack`` rejects an empty list.
        """
        msg = '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
        for key in self.buffer_keys:
            values = self.buffer[key]
            if not values:
                continue
            if key == 'pck':
                pcks = torch.stack(values).mean(dim=0) * 100
                val = ''.join('%5.2f ' % p.item() for p in pcks)
                msg += 'Avg %s: %s ' % (key.upper(), val)
            else:
                msg += 'Avg %s: %5.2f ' % (key.upper(), sum(values) / len(values))
        Logger.info(msg)

    def get_test_result(self):
        r""" Return ``{key: 100 * mean over buffered tensors}`` for every key.

        Each buffer must hold tensors of equal shape. ``torch.stack`` raises
        if a buffer is empty.
        """
        return {key: torch.stack(self.buffer[key]).mean(dim=0) * 100
                for key in self.buffer_keys}