# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from functools import partial

import torch
import torch.nn as nn

from imaginaire.layers import Conv2dBlock


class NonLocal2dBlock(nn.Module):
    r"""Self-attention layer.

    Args:
        in_channels (int): Number of channels in the input tensor.
        scale (bool, optional, default=True): If ``True``, scale the
            output by a learnable parameter.
        clamp (bool, optional, default=False): If ``True``, clamp the
            scaling parameter to (-1, 1).
        weight_norm_type (str, optional, default='none'):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``.
        weight_norm_params (obj, optional, default=None):
            Parameters of weight normalization.
            If not ``None``, ``weight_norm_params.__dict__`` will be used as
            keyword arguments when initializing weight normalization.
        bias (bool, optional, default=True): If ``True``, adds bias in the
            convolutional blocks.
    """

    def __init__(self,
                 in_channels,
                 scale=True,
                 clamp=False,
                 weight_norm_type='none',
                 weight_norm_params=None,
                 bias=True):
        super(NonLocal2dBlock, self).__init__()
        self.clamp = clamp
        # Residual scale, initialized to zero so the block starts as the
        # identity. Note that ``clamp=True`` requires ``scale=True``, since a
        # plain float has no ``clamp`` method.
        self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
        self.in_channels = in_channels
        # All projections are 1x1 convolutions sharing the same settings.
        base_conv2d_block = partial(Conv2dBlock,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    weight_norm_type=weight_norm_type,
                                    weight_norm_params=weight_norm_params,
                                    bias=bias)
        self.theta = base_conv2d_block(in_channels, in_channels // 8)
        self.phi = base_conv2d_block(in_channels, in_channels // 8)
        self.g = base_conv2d_block(in_channels, in_channels // 2)
        self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
        self.softmax = nn.Softmax(dim=-1)
        # Keys and values are max-pooled 2x to reduce the cost of attention.
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
        r"""

        Args:
            x (tensor): Input feature maps (B, C, H, W).
        Returns:
            out (tensor): Self-attention value + input feature.
        """
        n, c, h, w = x.size()
        # Queries: (n, h * w, c // 8).
        theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1)
        # Keys, max-pooled 2x: (n, c // 8, h * w // 4). The view assumes h
        # and w are even so that pooling halves each dimension exactly.
        phi = self.phi(x)
        phi = self.max_pool(phi).view(n, -1, h * w // 4)
        # Attention over the pooled positions: (n, h * w, h * w // 4).
        energy = torch.bmm(theta, phi)
        attention = self.softmax(energy)
        # Values, max-pooled 2x: (n, c // 2, h * w // 4).
        g = self.g(x)
        g = self.max_pool(g).view(n, -1, h * w // 4)
        # Aggregate values with the attention weights: (n, c // 2, h * w).
        out = torch.bmm(g, attention.permute(0, 2, 1))
        out = out.view(n, c // 2, h, w)
        out = self.out_conv(out)
        # Scaled residual connection back to the input.
        if self.clamp:
            out = self.gamma.clamp(-1, 1) * out + x
        else:
            out = self.gamma * out + x
        return out
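

# A minimal usage sketch (an addition, not part of the original imaginaire
# source). It assumes imaginaire is installed so the Conv2dBlock import above
# resolves, that in_channels is at least 8, and that the spatial dimensions
# are even, as the block requires.
if __name__ == '__main__':
    block = NonLocal2dBlock(in_channels=64)
    x = torch.randn(2, 64, 32, 32)  # (B, C, H, W) random features
    y = block(x)
    # The non-local block preserves the input shape.
    assert y.shape == x.shape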