# torch_utils.py (revision 81fb07b, 3.6 kB)
from collections import OrderedDict
import os
import numpy as np
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# Make the parent directory importable so the sibling ``data`` module can be
# found when this file is run from inside its own package directory.
sys.path.append('..')
try:
    import data
except ImportError:
    # Installed-package layout. Use ``from ... import data`` so that the name
    # ``data`` (used by get_dataloader) is actually bound; a plain
    # ``import a.b.c.data`` would only bind the top-level name ``foleycrafter``.
    from foleycrafter.models.specvqgan.onset_baseline import data
# ---------------------------------------------------- #
def load_model(cp_path, net, device=None, strict=True):
    """Load the weights stored in checkpoint ``cp_path`` into ``net``.

    The checkpoint must be a dict with a ``'state_dict'`` entry and may carry
    an ``'epoch'`` entry. A leading ``'module.'`` on parameter names (typically
    added when the model was wrapped in ``nn.DataParallel`` at save time) is
    stripped so the weights load into an unwrapped model.

    Args:
        cp_path: path to the checkpoint file.
        net: module whose ``load_state_dict`` receives the weights.
        device: ``map_location`` passed to ``torch.load``; CPU when falsy.
        strict: forwarded to ``net.load_state_dict``.

    Returns:
        ``(net, start_epoch)``, where ``start_epoch`` is the checkpoint's
        ``'epoch'`` value, or 0 if the checkpoint does not record one.

    If ``cp_path`` is not a file, prints a message and terminates the process
    via ``SystemExit`` with status 1.
    """
    if not device:
        device = torch.device('cpu')
    if not os.path.isfile(cp_path):
        print("=> no checkpoint found at '{}'".format(cp_path))
        # Non-zero status so shells / job schedulers see the failure; a bare
        # sys.exit() would report success.
        sys.exit(1)

    print("=> loading checkpoint '{}'".format(cp_path))
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(cp_path, map_location=device)

    # Strip the prefix per key: this handles empty state dicts and mixed
    # prefixed / non-prefixed keys without mangling names.
    prefix = 'module.'
    state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint['state_dict'].items()
    )
    net.load_state_dict(state_dict, strict=strict)

    start_epoch = checkpoint.get('epoch', 0)
    print("=> loaded checkpoint '{}' (epoch {})".format(cp_path, start_epoch))
    return net, start_epoch
# ---------------------------------------------------- #
def binary_acc(pred, target, thred):
    """Return the fraction of thresholded predictions that match ``target``.

    Each score in ``pred`` is binarized as ``score > thred`` and compared
    elementwise with ``target``. The number of matches is divided by
    ``target.shape[0]``.

    Assumes NumPy-compatible arrays (e.g. 1-D scores and 0/1 or boolean
    labels) -- TODO confirm against callers.
    """
    is_positive = pred > thred
    n_correct = np.sum(is_positive == target)
    n_total = target.shape[0]
    return n_correct / n_total
def calc_acc(prob, labels, k):
    """Return top-``k`` accuracy as a float tensor.

    Ranks the entries of ``prob`` along its last dimension in descending
    order and keeps the first ``k`` indices. It then counts how many of those
    indices equal the corresponding label (``labels`` reshaped to a column)
    and divides by ``labels.size(0)``.

    Presumably ``prob`` is (batch, num_classes) scores and ``labels`` holds
    integer class ids of length batch -- TODO confirm against callers.
    """
    ranked = torch.argsort(prob, dim=-1, descending=True)
    top_k = ranked[..., :k]
    hits = top_k == labels.view(-1, 1)
    return torch.sum(hits).float() / labels.size(0)
# ---------------------------------------------------- #
def get_dataloader(args, pr, split='train', shuffle=False, drop_last=False, batch_size=None):
    """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

    Args:
        args: run options. ``args.num_workers`` is always read;
            ``args.batch_size`` is read when ``batch_size`` is falsy.
        pr: parameter object. ``pr.dataloader`` names an attribute of the
            ``data`` module that is called to construct the dataset.
            ``pr.list_train`` / ``pr.list_val`` / ``pr.list_test`` give the
            read list for each split.
        split: one of ``'train'``, ``'val'``, ``'test'``.
        shuffle: forwarded to ``DataLoader``.
        drop_last: forwarded to ``DataLoader``.
        batch_size: overrides ``args.batch_size`` when truthy.

    Returns:
        ``(dataset, loader)``.

    Raises:
        ValueError: if ``split`` is not a supported split name (previously an
            unbound-variable error).
    """
    split_lists = {'train': 'list_train', 'val': 'list_val', 'test': 'list_test'}
    if split not in split_lists:
        raise ValueError(
            "split must be one of {}, got {!r}".format(sorted(split_lists), split))
    read_list = getattr(pr, split_lists[split])

    data_loader = getattr(data, pr.dataloader)
    dataset = data_loader(args, pr, read_list, split=split)
    batch_size = batch_size if batch_size else args.batch_size
    # NOTE(review): looks like a leftover sanity/debug check -- it calls the
    # dataset's ``getitem_test`` on index 1 every time a loader is built.
    # Kept to preserve behavior; confirm whether it is still wanted.
    dataset.getitem_test(1)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=drop_last)
    return dataset, loader
# ---------------------------------------------------- #
def make_optimizer(model, args):
    '''
    Build an optimizer over the trainable parameters of ``model``.

    Args:
        model: NN to train.
        args: options object. Reads ``optim`` ('SGD' or 'Adam'), ``lr``,
            ``weight_decay`` and, for SGD, ``momentum``.
    Returns:
        optimizer: pytorch optimizer that updates the parameters of the given
        model that have ``requires_grad`` set.
    Raises:
        ValueError: if ``args.optim`` is not a supported optimizer name
            (previously an unbound-variable error).
    '''
    # Frozen parameters (requires_grad=False) are not handed to the optimizer.
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'SGD':
        return torch.optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=False,
        )
    if args.optim == 'Adam':
        return torch.optim.Adam(
            params,
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    raise ValueError(
        "Unsupported optimizer {!r}; expected 'SGD' or 'Adam'".format(args.optim))
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group according to ``args.schedule``.

    With ``'cos'``, the rate follows a cosine curve from ``args.lr`` at epoch 0
    down to 0 at ``epoch == args.epochs``. Any other schedule value (including
    ``'none'``) leaves the base rate ``args.lr`` unchanged.
    """
    base_lr = args.lr
    if args.schedule != 'cos':
        new_lr = base_lr
    else:
        cosine_factor = 0.5 * (1. + np.cos(np.pi * epoch / args.epochs))
        new_lr = base_lr * cosine_factor
    for group in optimizer.param_groups:
        group['lr'] = new_lr