# FunSR / models / baselines / dynamic_layers.py
# KyanChen's picture
# add
# 02c5426
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm
class ScaleAwareAttention2d(nn.Module):
    """Scale-conditioned attention over ``K`` candidates.

    The input is global-average-pooled, passed through a 1x1 convolution and
    a ReLU, concatenated with two extra channels filled with ``scale``, and
    mapped by a second 1x1 convolution to ``K`` logits. A softmax divided by
    a temperature turns the logits into attention weights.

    Args:
        in_channels: Number of channels of the input feature map.
        ratios: Multiplier used to derive the hidden width
            (``int(in_channels * ratios) + 1``); ignored when
            ``in_channels == 3``, in which case the hidden width is ``K``.
        K: Number of attention outputs.
        temperature: Softmax temperature used in training mode. Must be
            ``>= 1`` and satisfy ``temperature % 3 == 1`` so that repeated
            decrements of 3 land exactly on 1.
        init_weight: When true, run ``_initialize_weights`` at construction.
    """

    def __init__(self, in_channels, ratios, K, temperature, init_weight=True):
        super().__init__()
        # ``-2 % 3 == 1`` in Python, so the modulo test alone would let
        # negative temperatures through; require a positive start as well.
        assert temperature >= 1 and temperature % 3 == 1, (
            "temperature must be >= 1 and satisfy temperature % 3 == 1, "
            f"got {temperature}"
        )
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        if in_channels != 3:
            hidden_channels = int(in_channels * ratios) + 1
        else:
            hidden_channels = K
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        # +2 input channels: forward() appends two channels holding ``scale``.
        self.fc2 = nn.Conv2d(hidden_channels + 2, K, 1, bias=True)
        self.temperature = temperature
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for conv weights, zero biases, unit BN scale."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # This class creates no BatchNorm2d itself; the branch only
                # matters if one is attached as a submodule.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def updata_temperature(self):
        """Lower the training temperature by 3, never going below 1.

        The misspelled name is kept because existing callers use it.
        """
        if self.temperature > 1:
            # Clamp so an off-schedule value cannot step past 1 into zero or
            # negative territory (which would break the softmax division).
            self.temperature = max(1, self.temperature - 3)

    def update_temperature(self):
        """Correctly spelled alias of :meth:`updata_temperature`."""
        self.updata_temperature()

    def forward(self, x, scale):
        """Return attention weights of shape ``(batch, K)``.

        Args:
            x: Input feature map; its first dimension is the batch size and
                its second must equal ``in_channels``.
            scale: Value written into the two extra channels. A Python
                number or anything ``torch.as_tensor`` accepts that
                broadcasts against ``(batch, 2, 1, 1)``.

        Returns:
            Softmax over the ``K`` logits. The temperature is
            ``self.temperature`` in training mode and 1 in eval mode.
        """
        temperature = self.temperature if self.training else 1
        batch_size = x.shape[0]

        x = F.relu(self.fc1(self.avgpool(x)))

        # Build the scale channels in the activations' dtype/device so that
        # half/bfloat16 models do not hand a float32 tensor to ``fc2``.
        scale = torch.as_tensor(scale, dtype=x.dtype, device=x.device)
        scale_channels = x.new_ones(batch_size, 2, 1, 1) * scale

        x = torch.cat([x, scale_channels], dim=1)
        x = self.fc2(x).view(batch_size, -1)
        return F.softmax(x / temperature, dim=1)
class ScaleAwareDynamicConv2d(nn.Module):
    """2-D convolution whose kernel is a per-sample mixture of ``K`` kernels.

    A ``ScaleAwareAttention2d`` produces, for every sample, ``K`` mixing
    weights from the input and a ``scale`` value. The ``K`` stored kernels
    (and biases) are blended with those weights, and the resulting
    per-sample kernels are applied in a single grouped convolution by
    folding the batch dimension into the channel dimension.

    Args:
        in_channels: Input channels; must be divisible by ``groups``.
        out_channels: Output channels; must be divisible by ``groups``.
        kernel_size: Side length of the square kernel (an int).
        ratio: Hidden-width ratio forwarded to the attention module.
        stride, padding, dilation, groups: Forwarded to ``F.conv2d``.
        bias: Whether to learn ``K`` bias vectors.
        K: Number of candidate kernels.
        temperature: Initial softmax temperature of the attention module.
        init_weight: When true, run ``_initialize_weights`` at construction.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio=0.25,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        K=4,
        temperature=34,
        init_weight=True,
    ):
        super().__init__()
        assert in_channels % groups == 0
        assert out_channels % groups == 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = ScaleAwareAttention2d(in_channels, ratio, K, temperature)

        self.weight = nn.Parameter(
            torch.randn(
                K, out_channels, in_channels // groups, kernel_size, kernel_size
            ),
            requires_grad=True,
        )
        if bias:
            # Start from zeros: ``torch.Tensor(K, out_channels)`` would hand
            # back uninitialised memory that nothing else ever overwrites.
            self.bias = nn.Parameter(torch.zeros(K, out_channels))
        else:
            self.bias = None

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform init for each candidate kernel; zero the biases."""
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def update_temperature(self):
        """Advance the attention module's temperature schedule by one step."""
        self.attention.updata_temperature()

    def forward(self, x, scale):
        """Apply the per-sample mixed kernels to ``x``.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.
            scale: Scale value forwarded to the attention module.

        Returns:
            Tensor of shape ``(batch, out_channels, out_height, out_width)``.
        """
        # (batch, K) mixing weights; computed before x is reshaped.
        softmax_attention = self.attention(x, scale)
        batch_size, _, height, width = x.size()

        # Fold the batch into channels so one grouped conv handles every
        # sample; ``reshape`` also accepts non-contiguous inputs.
        x = x.reshape(1, -1, height, width)

        weight = self.weight.view(self.K, -1)
        # Each sample's kernel has in_channels // groups input channels, the
        # layout ``F.conv2d`` expects with ``groups * batch_size`` groups.
        aggregate_weight = torch.mm(softmax_attention, weight).view(
            batch_size * self.out_channels,
            self.in_channels // self.groups,
            *self.weight.shape[3:],
        )
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
        else:
            aggregate_bias = None

        output = F.conv2d(
            x,
            weight=aggregate_weight,
            bias=aggregate_bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups * batch_size,
        )
        return output.view(
            batch_size, self.out_channels, output.size(-2), output.size(-1)
        )