""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import copy | |
import inspect | |
from typing import Any, Dict, List, Optional, Tuple, Type, Union | |
import numpy as np | |
import torch as th | |
import torch.nn.functional as thf | |
from torch.nn import init | |
from torch.nn.modules.utils import _pair | |
from torch.nn.utils.weight_norm import remove_weight_norm, WeightNorm | |
fc_default_activation = th.nn.LeakyReLU(0.2, inplace=True) | |


def gaussian_kernel(ksize: int, std: Optional[float] = None) -> np.ndarray:
    """Generates a numpy array filled in with Gaussian values.

    The function generates a Gaussian kernel (values following the Gaussian distribution)
    on a grid determined by the kernel size.

    Args:
        ksize (int): The kernel size, must be an odd number larger than 1. Otherwise an exception is raised.
        std (float): The standard deviation. May be None, in which case it is computed
            from the kernel size.

    Returns:
        np.array: The Gaussian kernel.
    """
    assert ksize % 2 == 1
    radius = ksize // 2
    if std is None:
        std = np.sqrt(-(radius**2) / (2 * np.log(0.05)))
    x, y = np.meshgrid(np.linspace(-radius, radius, ksize), np.linspace(-radius, radius, ksize))
    xy = np.stack([x, y], axis=2)
    gk = np.exp(-(xy**2).sum(-1) / (2 * std**2))
    gk /= gk.sum()
    return gk
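

# Illustrative usage (not part of the original module): the returned kernel is
# normalized to sum to 1, so it can be used directly as a blur filter, as done in
# ConcatPyramid below.
# >>> k = gaussian_kernel(7)
# >>> assert k.shape == (7, 7) and np.isclose(k.sum(), 1.0)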


class FCLayer(th.nn.Module):
    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, n_in, n_out, nonlin=fc_default_activation) -> None:
        super().__init__()

        self.fc = th.nn.Linear(n_in, n_out, bias=True)
        # pyre-fixme[4]: Attribute must be annotated.
        self.nonlin = nonlin if nonlin is not None else lambda x: x

        self.fc.bias.data.fill_(0)
        th.nn.init.xavier_uniform_(self.fc.weight.data)

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, x):
        x = self.fc(x)
        x = self.nonlin(x)
        return x


# pyre-fixme[2]: Parameter must be annotated.
def check_args_shadowing(name, method: object, arg_names) -> None:
    spec = inspect.getfullargspec(method)
    init_args = {*spec.args, *spec.kwonlyargs}
    for arg_name in arg_names:
        if arg_name in init_args:
            raise TypeError(f"{name} attempted to shadow a wrapped argument: {arg_name}")


# For backward compatibility.
class TensorMappingHook(object):
    def __init__(
        self,
        name_mapping: List[Tuple[str, str]],
        expected_shape: Optional[Dict[str, List[int]]] = None,
    ) -> None:
        """This hook is expected to be used with "_register_load_state_dict_pre_hook" to
        modify names and tensor shapes in the loaded state dictionary.

        Args:
            name_mapping: list of string tuples
                A list of tuples containing expected names from the state dict and names expected
                by the module.

            expected_shape: dict
                A mapping from parameter names to expected tensor shapes.
        """
        self.name_mapping = name_mapping
        # pyre-fixme[4]: Attribute must be annotated.
        self.expected_shape = expected_shape if expected_shape is not None else {}

    def __call__(
        self,
        # pyre-fixme[2]: Parameter must be annotated.
        state_dict,
        # pyre-fixme[2]: Parameter must be annotated.
        prefix,
        # pyre-fixme[2]: Parameter must be annotated.
        local_metadata,
        # pyre-fixme[2]: Parameter must be annotated.
        strict,
        # pyre-fixme[2]: Parameter must be annotated.
        missing_keys,
        # pyre-fixme[2]: Parameter must be annotated.
        unexpected_keys,
        # pyre-fixme[2]: Parameter must be annotated.
        error_msgs,
    ) -> None:
        for old_name, new_name in self.name_mapping:
            if prefix + old_name in state_dict:
                tensor = state_dict.pop(prefix + old_name)
                if new_name in self.expected_shape:
                    tensor = tensor.view(*self.expected_shape[new_name])
                state_dict[prefix + new_name] = tensor
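

# Illustrative usage (assumed, not part of the original module): remap a legacy
# parameter name and reshape it while loading an old checkpoint. The names and
# shapes below are hypothetical.
# >>> module._register_load_state_dict_pre_hook(
# ...     TensorMappingHook([("old_weight", "weight")], {"weight": [16, 8]})
# ... )
# >>> module.load_state_dict(legacy_state_dict)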


# pyre-fixme[3]: Return type must be annotated.
def weight_norm_wrapper(
    cls: Type[th.nn.Module],
    new_cls_name: str,
    name: str = "weight",
    g_dim: int = 0,
    v_dim: Optional[int] = 0,
):
    """Wraps a torch.nn.Module class to support weight normalization. The wrapped class
    is compatible with the fuse/unfuse syntax and is able to load state dicts from previous
    implementations.

    Args:
        cls: Type[th.nn.Module]
            Class to apply the wrapper to.

        new_cls_name: str
            Name of the new class created by the wrapper. This should be the name
            of whatever variable you assign the result of this function to. Ex:
            ``SomeLayerWN = weight_norm_wrapper(SomeLayer, "SomeLayerWN", ...)``

        name: str
            Name of the parameter to apply weight normalization to.

        g_dim: int
            Learnable dimension of the magnitude tensor. Set to None or -1 for a single scalar magnitude.
            Default values are 0 for Linear and Conv2d layers and 1 for ConvTranspose2d layers.

        v_dim: int
            Dimension of the direction tensor over which the norm is calculated independently. Set to
            None or -1 to calculate the norm over the entire direction (weight) tensor. Default
            values for most of the WN layers are None to preserve the existing behavior.
    """

    class Wrap(cls):
        def __init__(self, *args: Any, name=name, g_dim=g_dim, v_dim=v_dim, **kwargs: Any):
            # Check if the extra arguments are overwriting arguments for the wrapped class
            check_args_shadowing(
                "weight_norm_wrapper", super().__init__, ["name", "g_dim", "v_dim"]
            )
            super().__init__(*args, **kwargs)

            # Sanitize v_dim since we are hacking the built-in utility to support
            # a non-standard WeightNorm implementation.
            if v_dim is None:
                v_dim = -1
            self.weight_norm_args = {"name": name, "g_dim": g_dim, "v_dim": v_dim}
            self.is_fused = True
            self.unfuse()

            # For backward compatibility.
            self._register_load_state_dict_pre_hook(
                TensorMappingHook(
                    [(name, name + "_v"), ("g", name + "_g")],
                    {name + "_g": getattr(self, name + "_g").shape},
                )
            )

        def fuse(self):
            if self.is_fused:
                return
            # Check if the module is frozen.
            param_name = self.weight_norm_args["name"] + "_g"
            if hasattr(self, param_name) and param_name not in self._parameters:
                raise ValueError("Trying to fuse frozen module.")
            remove_weight_norm(self, self.weight_norm_args["name"])
            self.is_fused = True

        def unfuse(self):
            if not self.is_fused:
                return
            # Check if the module is frozen.
            param_name = self.weight_norm_args["name"]
            if hasattr(self, param_name) and param_name not in self._parameters:
                raise ValueError("Trying to unfuse frozen module.")

            wn = WeightNorm.apply(
                self, self.weight_norm_args["name"], self.weight_norm_args["g_dim"]
            )
            # Overwrite the dim property to support mismatched norm calculation for the v and g tensors.
            if wn.dim != self.weight_norm_args["v_dim"]:
                wn.dim = self.weight_norm_args["v_dim"]

                # Adjust the norm values.
                weight = getattr(self, self.weight_norm_args["name"] + "_v")
                norm = getattr(self, self.weight_norm_args["name"] + "_g")
                norm.data[:] = th.norm_except_dim(weight, 2, wn.dim)

            self.is_fused = False

        def __deepcopy__(self, memo):
            # Delete derived tensor to avoid deepcopy error.
            if not self.is_fused:
                delattr(self, self.weight_norm_args["name"])

            # Deepcopy.
            cls = self.__class__
            result = cls.__new__(cls)
            memo[id(self)] = result
            for k, v in self.__dict__.items():
                setattr(result, k, copy.deepcopy(v, memo))

            if not self.is_fused:
                setattr(result, self.weight_norm_args["name"], None)
                setattr(self, self.weight_norm_args["name"], None)
            return result

    # Allows for pickling of the wrapper: https://bugs.python.org/issue13520
    Wrap.__qualname__ = new_cls_name

    return Wrap


# pyre-fixme[2]: Parameter must be annotated.
def is_weight_norm_wrapped(module) -> bool:
    for hook in module._forward_pre_hooks.values():
        if isinstance(hook, WeightNorm):
            return True
    return False


class Conv2dUB(th.nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        height: int,
        width: int,
        # pyre-fixme[2]: Parameter must be annotated.
        *args,
        bias: bool = True,
        # pyre-fixme[2]: Parameter must be annotated.
        **kwargs,
    ) -> None:
        """Conv2d with untied bias."""
        super().__init__(in_channels, out_channels, *args, bias=False, **kwargs)
        # pyre-fixme[4]: Attribute must be annotated.
        self.bias = th.nn.Parameter(th.zeros(out_channels, height, width)) if bias else None

    # TODO: remove this method once upgraded to pytorch 1.8
    # pyre-fixme[3]: Return type must be annotated.
    def _conv_forward(self, input: th.Tensor, weight: th.Tensor, bias: Optional[th.Tensor]):
        # Copied from pt1.8 source code.
        if self.padding_mode != "zeros":
            input = thf.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            return thf.conv2d(
                input, weight, bias, self.stride, _pair(0), self.dilation, self.groups
            )
        return thf.conv2d(
            input,
            weight,
            bias,
            self.stride,
            # pyre-fixme[6]: For 5th param expected `Union[List[int], int, Size,
            #  typing.Tuple[int, ...]]` but got `Union[str, typing.Tuple[int, ...]]`.
            self.padding,
            self.dilation,
            self.groups,
        )

    def forward(self, input: th.Tensor) -> th.Tensor:
        output = self._conv_forward(input, self.weight, None)
        bias = self.bias
        if bias is not None:
            # Assertion for jit script.
            assert bias is not None
            output = output + bias[None]
        return output


class ConvTranspose2dUB(th.nn.ConvTranspose2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        height: int,
        width: int,
        # pyre-fixme[2]: Parameter must be annotated.
        *args,
        bias: bool = True,
        # pyre-fixme[2]: Parameter must be annotated.
        **kwargs,
    ) -> None:
        """ConvTranspose2d with untied bias."""
        super().__init__(in_channels, out_channels, *args, bias=False, **kwargs)

        if self.padding_mode != "zeros":
            raise ValueError("Only `zeros` padding mode is supported for ConvTranspose2dUB")

        # pyre-fixme[4]: Attribute must be annotated.
        self.bias = th.nn.Parameter(th.zeros(out_channels, height, width)) if bias else None

    def forward(self, input: th.Tensor, output_size: Optional[List[int]] = None) -> th.Tensor:
        # TODO(T111390117): Fix Conv member annotations.
        output_padding = self._output_padding(
            input=input,
            output_size=output_size,
            # pyre-fixme[6]: For 3rd param expected `List[int]` but got
            #  `Tuple[int, ...]`.
            stride=self.stride,
            # pyre-fixme[6]: For 4th param expected `List[int]` but got
            #  `Union[str, typing.Tuple[int, ...]]`.
            padding=self.padding,
            # pyre-fixme[6]: For 5th param expected `List[int]` but got
            #  `Tuple[int, ...]`.
            kernel_size=self.kernel_size,
            # This is now required as of D35874490
            num_spatial_dims=input.dim() - 2,
            # pyre-fixme[6]: For 6th param expected `Optional[List[int]]` but got
            #  `Tuple[int, ...]`.
            dilation=self.dilation,
        )

        output = thf.conv_transpose2d(
            input,
            self.weight,
            None,
            self.stride,
            # pyre-fixme[6]: For 5th param expected `Union[List[int], int, Size,
            #  typing.Tuple[int, ...]]` but got `Union[str, typing.Tuple[int, ...]]`.
            self.padding,
            output_padding,
            self.groups,
            self.dilation,
        )

        bias = self.bias
        if bias is not None:
            # Assertion for jit script.
            assert bias is not None
            output = output + bias[None]
        return output

    # NOTE: This function (on super _ConvTransposeNd) was updated in D35874490 with non-optional
    # param num_spatial_dims added. Since we need both old/new pytorch versions to work (until those
    # changes reach DGX), we're simply copying the updated code here until then.
    # TODO remove this function once updated torch code is released to DGX
    def _output_padding(
        self,
        input: th.Tensor,
        output_size: Optional[List[int]],
        stride: List[int],
        padding: List[int],
        kernel_size: List[int],
        num_spatial_dims: int,
        dilation: Optional[List[int]] = None,
    ) -> List[int]:
        if output_size is None:
            # converting to list if it was not already
            ret = th.nn.modules.utils._single(self.output_padding)
        else:
            has_batch_dim = input.dim() == num_spatial_dims + 2
            num_non_spatial_dims = 2 if has_batch_dim else 1
            if len(output_size) == num_non_spatial_dims + num_spatial_dims:
                output_size = output_size[num_non_spatial_dims:]
            if len(output_size) != num_spatial_dims:
                raise ValueError(
                    "ConvTranspose{}D: for {}D input, output_size must have {} or {} elements (got {})".format(
                        num_spatial_dims,
                        input.dim(),
                        num_spatial_dims,
                        num_non_spatial_dims + num_spatial_dims,
                        len(output_size),
                    )
                )

            min_sizes = th.jit.annotate(List[int], [])
            max_sizes = th.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                dim_size = (
                    (input.size(d + num_non_spatial_dims) - 1) * stride[d]
                    - 2 * padding[d]
                    + (dilation[d] if dilation is not None else 1) * (kernel_size[d] - 1)
                    + 1
                )
                min_sizes.append(dim_size)
                max_sizes.append(min_sizes[d] + stride[d] - 1)

            for i in range(len(output_size)):
                size = output_size[i]
                min_size = min_sizes[i]
                max_size = max_sizes[i]
                if size < min_size or size > max_size:
                    raise ValueError(
                        (
                            "requested an output size of {}, but valid sizes range "
                            "from {} to {} (for an input of {})"
                        ).format(output_size, min_sizes, max_sizes, input.size()[2:])
                    )

            res = th.jit.annotate(List[int], [])
            for d in range(num_spatial_dims):
                res.append(output_size[d] - min_sizes[d])

            ret = res
        return ret


# Set default g_dim=0 (Conv2d) or 1 (ConvTranspose2d) and v_dim=None to preserve
# the current weight norm behavior.
# pyre-fixme[5]: Global expression must be annotated.
LinearWN = weight_norm_wrapper(th.nn.Linear, "LinearWN", g_dim=0, v_dim=None)
# pyre-fixme[5]: Global expression must be annotated.
Conv2dWN = weight_norm_wrapper(th.nn.Conv2d, "Conv2dWN", g_dim=0, v_dim=None)
# pyre-fixme[5]: Global expression must be annotated.
Conv2dWNUB = weight_norm_wrapper(Conv2dUB, "Conv2dWNUB", g_dim=0, v_dim=None)
# pyre-fixme[5]: Global expression must be annotated.
ConvTranspose2dWN = weight_norm_wrapper(
    th.nn.ConvTranspose2d, "ConvTranspose2dWN", g_dim=1, v_dim=None
)
# pyre-fixme[5]: Global expression must be annotated.
ConvTranspose2dWNUB = weight_norm_wrapper(
    ConvTranspose2dUB, "ConvTranspose2dWNUB", g_dim=1, v_dim=None
)


class InterpolateHook(object):
    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, size=None, scale_factor=None, mode: str = "bilinear") -> None:
        """An object storing options for the interpolate function."""
        # pyre-fixme[4]: Attribute must be annotated.
        self.size = size
        # pyre-fixme[4]: Attribute must be annotated.
        self.scale_factor = scale_factor
        self.mode = mode

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __call__(self, module, x):
        assert len(x) == 1, "Module should take only one input for the forward method."
        return thf.interpolate(
            x[0],
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=False,
        )


# pyre-fixme[3]: Return type must be annotated.
def interpolate_wrapper(cls: Type[th.nn.Module], new_cls_name: str):
    """Wraps a torch.nn.Module class and performs additional interpolation on the
    first and only positional input of the forward method.

    Args:
        cls: Type[th.nn.Module]
            Class to apply the wrapper to.

        new_cls_name: str
            Name of the new class created by the wrapper. This should be the name
            of whatever variable you assign the result of this function to. Ex:
            ``UpConv = interpolate_wrapper(Conv, "UpConv", ...)``
    """

    class Wrap(cls):
        def __init__(
            self, *args: Any, size=None, scale_factor=None, mode="bilinear", **kwargs: Any
        ):
            check_args_shadowing(
                "interpolate_wrapper", super().__init__, ["size", "scale_factor", "mode"]
            )
            super().__init__(*args, **kwargs)
            self.register_forward_pre_hook(
                InterpolateHook(size=size, scale_factor=scale_factor, mode=mode)
            )

    # Allows for pickling of the wrapper: https://bugs.python.org/issue13520
    Wrap.__qualname__ = new_cls_name

    return Wrap


# pyre-fixme[5]: Global expression must be annotated.
UpConv2d = interpolate_wrapper(th.nn.Conv2d, "UpConv2d")
# pyre-fixme[5]: Global expression must be annotated.
UpConv2dWN = interpolate_wrapper(Conv2dWN, "UpConv2dWN")
# pyre-fixme[5]: Global expression must be annotated.
UpConv2dWNUB = interpolate_wrapper(Conv2dWNUB, "UpConv2dWNUB")
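

# Illustrative usage (assumed, not part of the original module): UpConv2d first
# bilinearly resizes its input via the forward pre-hook and then applies the conv,
# so a 3x3 conv with scale_factor=2 doubles the spatial resolution.
# >>> up = UpConv2d(8, 16, kernel_size=3, padding=1, scale_factor=2)
# >>> up(th.randn(1, 8, 32, 32)).shape  # -> torch.Size([1, 16, 64, 64])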


class GlobalAvgPool(th.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, x):
        return x.view(x.shape[0], x.shape[1], -1).mean(dim=2)


class Upsample(th.nn.Module):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        # pyre-fixme[4]: Attribute must be annotated.
        self.args = args
        # pyre-fixme[4]: Attribute must be annotated.
        self.kwargs = kwargs

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, x):
        return thf.interpolate(x, *self.args, **self.kwargs)


class DenseAffine(th.nn.Module):
    # Per-pixel affine transform layer.

    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, shape) -> None:
        super().__init__()
        self.W = th.nn.Parameter(th.ones(*shape))
        self.b = th.nn.Parameter(th.zeros(*shape))

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, x, scale=None, crop=None):
        W = self.W
        b = self.b

        if scale is not None:
            W = thf.interpolate(W, scale_factor=scale, mode="bilinear")
            b = thf.interpolate(b, scale_factor=scale, mode="bilinear")
        if crop is not None:
            W = W[..., crop[0] : crop[1], crop[2] : crop[3]]
            b = b[..., crop[0] : crop[1], crop[2] : crop[3]]

        return x * W + b


def glorot(m: th.nn.Module, alpha: float = 1.0) -> None:
    gain = np.sqrt(2.0 / (1.0 + alpha**2))

    if isinstance(m, th.nn.Conv2d):
        ksize = m.kernel_size[0] * m.kernel_size[1]
        n1 = m.in_channels
        n2 = m.out_channels
        std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, th.nn.ConvTranspose2d):
        ksize = m.kernel_size[0] * m.kernel_size[1] // 4
        n1 = m.in_channels
        n2 = m.out_channels
        std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, th.nn.ConvTranspose3d):
        ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] // 8
        n1 = m.in_channels
        n2 = m.out_channels
        std = gain * np.sqrt(2.0 / ((n1 + n2) * ksize))
    elif isinstance(m, th.nn.Linear):
        n1 = m.in_features
        n2 = m.out_features
        std = gain * np.sqrt(2.0 / (n1 + n2))
    else:
        return

    is_wnw = is_weight_norm_wrapped(m)
    if is_wnw:
        m.fuse()

    m.weight.data.uniform_(-std * np.sqrt(3.0), std * np.sqrt(3.0))
    if m.bias is not None:
        m.bias.data.zero_()

    if isinstance(m, th.nn.ConvTranspose2d):
        # hardcoded for stride=2 for now
        m.weight.data[:, :, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2]
        m.weight.data[:, :, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2]

    if is_wnw:
        m.unfuse()
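

# Illustrative usage (assumed, not part of the original module): `glorot` is written to
# be used with `Module.apply`, so it recursively initializes every supported (and
# optionally weight-norm-wrapped) layer in a network; unsupported modules are skipped.
# >>> net = th.nn.Sequential(LinearWN(16, 32), th.nn.LeakyReLU(0.2), LinearWN(32, 8))
# >>> net.apply(lambda m: glorot(m, alpha=0.2))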


def make_tuple(x: Union[int, Tuple[int, int]], n: int) -> Tuple[int, int]:
    if isinstance(x, int):
        return tuple([x for _ in range(n)])
    else:
        return x


class LinearELR(th.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        gain: Optional[float] = None,
        lr_mul: float = 1.0,
        bias_lr_mul: Optional[float] = None,
    ) -> None:
        super(LinearELR, self).__init__()
        self.in_features = in_features
        self.weight = th.nn.Parameter(th.zeros(out_features, in_features, dtype=th.float32))
        if bias:
            self.bias: th.nn.Parameter = th.nn.Parameter(th.zeros(out_features, dtype=th.float32))
        else:
            self.register_parameter("bias", None)
        self.std: float = 0.0
        if gain is None:
            self.gain: float = np.sqrt(2.0)
        else:
            self.gain: float = gain
        self.lr_mul = lr_mul
        if bias_lr_mul is None:
            bias_lr_mul = lr_mul
        self.bias_lr_mul = bias_lr_mul
        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.std = self.gain / np.sqrt(self.in_features) * self.lr_mul
        init.normal_(self.weight, mean=0, std=1.0 / self.lr_mul)
        if self.bias is not None:
            with th.no_grad():
                self.bias.zero_()

    def forward(self, x: th.Tensor) -> th.Tensor:
        bias = self.bias
        if bias is not None:
            bias = bias * self.bias_lr_mul
        return thf.linear(x, self.weight.mul(self.std), bias)
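

# Illustrative note (assumed, not part of the original module): LinearELR applies the
# equalized-learning-rate trick from progressive/StyleGAN-style training. Weights are
# initialized with std 1/lr_mul and rescaled at every forward pass by the per-layer
# constant std = gain / sqrt(fan_in) * lr_mul, so lr_mul controls the effective
# per-layer learning rate.
# >>> fc = LinearELR(512, 512, lr_mul=0.01)
# >>> out = fc(th.randn(4, 512))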


class Conv2dELR(th.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        output_padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        untied: bool = False,
        height: int = 1,
        width: int = 1,
        gain: Optional[float] = None,
        transpose: bool = False,
        fuse_box_filter: bool = False,
        lr_mul: float = 1.0,
        bias_lr_mul: Optional[float] = None,
    ) -> None:
        super().__init__()
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size: Tuple[int, int] = make_tuple(kernel_size, 2)
        self.stride: Tuple[int, int] = make_tuple(stride, 2)
        self.padding: Tuple[int, int] = make_tuple(padding, 2)
        self.output_padding: Tuple[int, int] = make_tuple(output_padding, 2)
        self.dilation: Tuple[int, int] = make_tuple(dilation, 2)
        self.groups = groups
        if gain is None:
            self.gain: float = np.sqrt(2.0)
        else:
            self.gain: float = gain
        self.lr_mul = lr_mul
        if bias_lr_mul is None:
            bias_lr_mul = lr_mul
        self.bias_lr_mul = bias_lr_mul
        self.transpose = transpose
        self.fan_in: float = np.prod(self.kernel_size) * in_channels // groups
        self.fuse_box_filter = fuse_box_filter
        if transpose:
            self.weight: th.nn.Parameter = th.nn.Parameter(
                th.zeros(in_channels, out_channels // groups, *self.kernel_size, dtype=th.float32)
            )
        else:
            self.weight: th.nn.Parameter = th.nn.Parameter(
                th.zeros(out_channels, in_channels // groups, *self.kernel_size, dtype=th.float32)
            )
        if bias:
            if untied:
                self.bias: th.nn.Parameter = th.nn.Parameter(
                    th.zeros(out_channels, height, width, dtype=th.float32)
                )
            else:
                self.bias: th.nn.Parameter = th.nn.Parameter(
                    th.zeros(out_channels, dtype=th.float32)
                )
        else:
            self.register_parameter("bias", None)
        self.untied = untied
        self.std: float = 0.0
        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.std = self.gain / np.sqrt(self.fan_in) * self.lr_mul
        init.normal_(self.weight, mean=0, std=1.0 / self.lr_mul)
        if self.bias is not None:
            with th.no_grad():
                self.bias.zero_()

    def forward(self, x: th.Tensor) -> th.Tensor:
        if self.transpose:
            w = self.weight
            if self.fuse_box_filter:
                w = thf.pad(w, (1, 1, 1, 1), mode="constant")
                w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            bias = self.bias
            if bias is not None:
                bias = bias * self.bias_lr_mul
            out = thf.conv_transpose2d(
                x,
                w * self.std,
                bias if not self.untied else None,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
                dilation=self.dilation,
                groups=self.groups,
            )
            if self.untied and bias is not None:
                out = out + bias[None, ...]
            return out
        else:
            w = self.weight
            if self.fuse_box_filter:
                w = thf.pad(w, (1, 1, 1, 1), mode="constant")
                w = (
                    w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
                ) * 0.25
            bias = self.bias
            if bias is not None:
                bias = bias * self.bias_lr_mul
            out = thf.conv2d(
                x,
                w * self.std,
                bias if not self.untied else None,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )
            if self.untied and bias is not None:
                out = out + bias[None, ...]
            return out
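

# Illustrative note (assumed, not part of the original module): when fuse_box_filter is
# set, the kernel is padded and summed over shifted copies, which appears equivalent to
# pre-convolving the weight with a 2x2 box filter (averaged for conv, summed for
# transposed conv). This fuses a small smoothing step into the layer without an extra op.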


class ConcatPyramid(th.nn.Module):
    def __init__(
        self,
        # pyre-fixme[2]: Parameter must be annotated.
        branch,
        # pyre-fixme[2]: Parameter must be annotated.
        n_concat_in,
        every_other: bool = True,
        ksize: int = 7,
        # pyre-fixme[2]: Parameter must be annotated.
        kstd=None,
        transposed: bool = False,
    ) -> None:
        """Module which wraps an up/down conv branch taking one input X and
        converts it into a branch which takes two inputs X, Y. At each layer of
        the original branch, we concatenate the previous output and Y,
        up/downsampling Y appropriately, before running the layer.

        Args:
            branch: th.nn.Sequential or th.nn.ModuleList
                A branch containing up/down convs, optionally separated by nonlinearities.

            n_concat_in: int
                Number of channels in the to-be-concatenated input (Y).

            every_other: bool
                If every other layer is a nonlinearity, set this flag. Default is on.

            ksize: int
                Kernel size for the Gaussian blur used to downsample each step of the pyramid.

            kstd: float
                Kernel std. dev. for the Gaussian blur used to downsample each step of the pyramid.
                If None, it is determined automatically.

            transposed: bool
                Whether the conv stack contains transposed convolutions.
        """
        super().__init__()
        assert isinstance(branch, (th.nn.Sequential, th.nn.ModuleList))

        # pyre-fixme[4]: Attribute must be annotated.
        self.branch = branch
        # pyre-fixme[4]: Attribute must be annotated.
        self.n_concat_in = n_concat_in
        self.every_other = every_other
        self.ksize = ksize
        # pyre-fixme[4]: Attribute must be annotated.
        self.kstd = kstd
        self.transposed = transposed

        if every_other:
            # pyre-fixme[4]: Attribute must be annotated.
            self.levels = int(np.ceil(len(branch) / 2))
        else:
            self.levels = len(branch)

        kernel = th.from_numpy(gaussian_kernel(ksize, kstd)).float()
        self.register_buffer("blur_kernel", kernel[None, None].expand(n_concat_in, -1, -1, -1))

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, x, y):
        if self.transposed:
            blurred = thf.conv2d(
                y, self.blur_kernel, groups=self.n_concat_in, padding=self.ksize // 2
            )
            pyramid = [blurred[:, :, ::2, ::2]]
        else:
            pyramid = [y]

        for _ in range(self.levels - 1):
            blurred = thf.conv2d(
                pyramid[0], self.blur_kernel, groups=self.n_concat_in, padding=self.ksize // 2
            )
            pyramid.insert(0, blurred[:, :, ::2, ::2])

        out = x
        for i, layer in enumerate(self.branch):
            if (i % 2) == 0 or not self.every_other:
                idx = i // 2 if self.every_other else i
                out = th.cat([out, pyramid[idx]], dim=1)
            out = layer(out)
        return out
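

# Illustrative usage (assumed, not part of the original module): each conv in `branch`
# must accept `n_concat_in` extra input channels; Y is supplied at the branch's output
# resolution and is blurred/downsampled into a pyramid matching each conv's input size.
# >>> branch = th.nn.Sequential(
# ...     th.nn.ConvTranspose2d(8 + 3, 8, 4, 2, 1),
# ...     th.nn.LeakyReLU(0.2),
# ...     th.nn.ConvTranspose2d(8 + 3, 4, 4, 2, 1),
# ... )
# >>> net = ConcatPyramid(branch, n_concat_in=3, transposed=True)
# >>> out = net(th.randn(1, 8, 16, 16), th.randn(1, 3, 64, 64))  # -> (1, 4, 64, 64)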


# From paper "Making Convolutional Networks Shift-Invariant Again"
# https://richzhang.github.io/antialiased-cnns/
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def get_pad_layer(pad_type):
    if pad_type in ["refl", "reflect"]:
        PadLayer = th.nn.ReflectionPad2d
    elif pad_type in ["repl", "replicate"]:
        PadLayer = th.nn.ReplicationPad2d
    elif pad_type == "zero":
        PadLayer = th.nn.ZeroPad2d
    else:
        print("Pad type [%s] not recognized" % pad_type)
    # pyre-fixme[61]: `PadLayer` is undefined, or not always defined.
    return PadLayer


class Downsample(th.nn.Module):
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __init__(self, pad_type="reflect", filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        # pyre-fixme[4]: Attribute must be annotated.
        self.filt_size = filt_size
        # pyre-fixme[4]: Attribute must be annotated.
        self.pad_off = pad_off
        # pyre-fixme[4]: Attribute must be annotated.
        self.pad_sizes = [
            int(1.0 * (filt_size - 1) / 2),
            int(np.ceil(1.0 * (filt_size - 1) / 2)),
            int(1.0 * (filt_size - 1) / 2),
            int(np.ceil(1.0 * (filt_size - 1) / 2)),
        ]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        # pyre-fixme[4]: Attribute must be annotated.
        self.stride = stride
        self.off = int((self.stride - 1) / 2.0)
        # pyre-fixme[4]: Attribute must be annotated.
        self.channels = channels

        # print('Filter size [%i]' % filt_size)
        if self.filt_size == 1:
            a = np.array([1.0])
        elif self.filt_size == 2:
            a = np.array([1.0, 1.0])
        elif self.filt_size == 3:
            a = np.array([1.0, 2.0, 1.0])
        elif self.filt_size == 4:
            a = np.array([1.0, 3.0, 3.0, 1.0])
        elif self.filt_size == 5:
            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
        elif self.filt_size == 6:
            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
        elif self.filt_size == 7:
            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])

        filt = th.Tensor(a[:, None] * a[None, :])
        filt = filt / th.sum(filt)
        self.register_buffer("filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        # pyre-fixme[4]: Attribute must be annotated.
        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def forward(self, inp):
        if self.filt_size == 1:
            if self.pad_off == 0:
                return inp[:, :, :: self.stride, :: self.stride]
            else:
                return self.pad(inp)[:, :, :: self.stride, :: self.stride]
        else:
            return th.nn.functional.conv2d(
                self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]
            )
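

# Illustrative usage (assumed, not part of the original module): following the
# antialiased-cnns recipe, a stride-2 conv can be replaced by a stride-1 conv followed
# by a blurred, strided Downsample so the downsampling step is closer to shift-invariant.
# >>> block = th.nn.Sequential(
# ...     th.nn.Conv2d(16, 32, 3, stride=1, padding=1),
# ...     th.nn.ReLU(inplace=True),
# ...     Downsample(filt_size=3, stride=2, channels=32),
# ... )
# >>> block(th.randn(1, 16, 64, 64)).shape  # -> torch.Size([1, 32, 32, 32])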