DeepLearning101's picture
Upload 16 files
b6c45cb verified
raw
history blame contribute delete
4.52 kB
import torch
from torch import nn
from collections import OrderedDict
def to_cuda(tensors):  # pragma: no cover (No CUDA on travis)
    """Transfer tensor, list, tuple or dict of tensors to GPU.

    Args:
        tensors (:class:`torch.Tensor`, list, tuple or dict): May be a
            single, a list, a tuple or a dictionary of tensors.

    Returns:
        :class:`torch.Tensor`, list, tuple or dict:
            Same as input but transferred to cuda. Goes through lists,
            tuples and dicts and transfers each torch.Tensor to cuda.

    Raises:
        TypeError: If ``tensors`` is not a tensor, list, tuple or dict.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.cuda()
    if isinstance(tensors, (list, tuple)):
        # Preserve the container type: list in -> list out, tuple in -> tuple
        # out. Tuple support added for consistency with `tensors_to_device`.
        return type(tensors)(to_cuda(tens) for tens in tensors)
    if isinstance(tensors, dict):
        # NOTE: the dict is mutated in place (and also returned); callers
        # may rely on this, so we keep the original behavior.
        for key in tensors.keys():
            tensors[key] = to_cuda(tensors[key])
        return tensors
    raise TypeError(
        "tensors must be a tensor or a list or dict of tensors. "
        " Got tensors of type {}".format(type(tensors))
    )
def tensors_to_device(tensors, device):
    """Transfer tensor, dict, list or tuple of tensors to device.

    Args:
        tensors (:class:`torch.Tensor`): May be a single, a list, a tuple or
            a dictionary of tensors.
        device (:class:`torch.device`): the device where to place the tensors.

    Returns:
        Union[:class:`torch.Tensor`, list, tuple, dict]:
            Same as input but transferred to device.
            Goes through lists, tuples and dicts and transfers each
            torch.Tensor to device. Leaves non-tensor leaves untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    elif isinstance(tensors, (list, tuple)):
        # Bug fix: preserve the container type. The previous version returned
        # a list even for tuple inputs, contradicting the docstring.
        return type(tensors)(tensors_to_device(tens, device) for tens in tensors)
    elif isinstance(tensors, dict):
        # NOTE: the dict is mutated in place (and also returned), matching
        # the original behavior callers may rely on.
        for key in tensors.keys():
            tensors[key] = tensors_to_device(tensors[key], device)
        return tensors
    else:
        # Non-tensor leaves (ints, strings, None, ...) pass through unchanged.
        return tensors
def pad_x_to_y(x, y, axis=-1):
    """Right-pad first argument to have same size as second argument.

    Generalized from the original ``axis=-1``-only version to support any
    axis (positive or negative); the default behavior is unchanged.

    Args:
        x (torch.Tensor): Tensor to be padded.
        y (torch.Tensor): Tensor to pad ``x`` to.
        axis (int): Axis to pad on. Defaults to -1 (last axis).

    Returns:
        torch.Tensor: ``x`` zero-padded on the right to match ``y``'s size
        along ``axis``. If ``x`` is longer than ``y`` on that axis, it is
        truncated instead (negative padding), as in the original.
    """
    if axis >= 0:
        # Normalize to a negative axis so F.pad's last-dims-first pad-list
        # convention can be applied uniformly.
        axis = axis - x.dim()
    target_len = y.size(axis)
    current_len = x.size(axis)
    # F.pad's pad list is (last_left, last_right, 2nd-last_left, ...), so the
    # right-pad entry for `axis` is the final element of a 2*|axis| list.
    pad = [0] * (-2 * axis)
    pad[-1] = target_len - current_len
    return nn.functional.pad(x, pad)
def load_state_dict_in(state_dict, model):
    """Strictly load ``state_dict`` into ``model``, retrying with the first
    key component stripped.

    Useful to load a standalone model after training it wrapped inside a
    bigger ``nn.Module`` (e.g. a ``System``), whose state_dict keys look
    like ``object1.object2.layer_name.weight``.

    Args:
        state_dict (OrderedDict): the state_dict to load.
        model (torch.nn.Module): the model to load it into.

    Returns:
        torch.nn.Module: model with loaded weights.
    """
    try:
        # A direct strict load succeeds when the state_dict was saved from
        # the model itself rather than from a wrapping module.
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Drop the left-most dotted component of every key
        # (object1.object2.x -> object2.x) and retry strictly. Blindly
        # loading with strict=False would need logging of the missing keys
        # in the state_dict and the model.
        stripped = OrderedDict(
            (key[key.find(".") + 1 :], value) for key, value in state_dict.items()
        )
        model.load_state_dict(stripped, strict=True)
    return model
def are_models_equal(model1, model2):
    """Check for weights equality between models.

    Args:
        model1 (nn.Module): model instance to be compared.
        model2 (nn.Module): second model instance to be compared.

    Returns:
        bool: Whether all model weights are equal.
    """
    paired = zip(model1.parameters(), model2.parameters())
    # Equal iff no parameter pair has any differing element.
    return all(int(left.data.ne(right.data).sum()) == 0 for left, right in paired)