import torch
import torch.nn as nn
from torch.nn import functional as F


class ScaleAwareAttention2d(nn.Module):
    """Produces per-sample softmax weights over K kernels, conditioned on
    globally pooled features and an explicit scale input."""

    def __init__(self, in_channels, ratios, K, temperature, init_weight=True):
        super().__init__()
        # Temperature is annealed in steps of 3 down to 1, so it must satisfy
        # temperature % 3 == 1 (e.g. 34, 31, ..., 4, 1).
        assert temperature % 3 == 1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        if in_channels != 3:
            hidden_channels = int(in_channels * ratios) + 1
        else:
            hidden_channels = K
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        # self.bn = nn.BatchNorm2d(hidden_channels)
        # Two extra input channels carry the scale signal concatenated in forward().
        self.fc2 = nn.Conv2d(hidden_channels + 2, K, 1, bias=True)
        self.temperature = temperature
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self):
        if self.temperature != 1:
            self.temperature -= 3
            # print('Change temperature to:', str(self.temperature))

    def forward(self, x, scale):
        # At inference time the attention is fully sharpened (temperature 1).
        temperature = self.temperature if self.training else 1
        batch_size = x.shape[0]
        x = self.avgpool(x)
        x = self.fc1(x)
        x = F.relu(x)
        # Append the scale as two constant feature maps.
        x = torch.cat(
            [x, torch.ones([batch_size, 2, 1, 1], device=x.device) * scale], dim=1
        )
        x = self.fc2(x).view(x.size(0), -1)
        return F.softmax(x / temperature, dim=1)


class ScaleAwareDynamicConv2d(nn.Module):
    """Dynamic convolution: K parallel kernels are aggregated per sample with
    scale-aware attention weights, then applied as one grouped convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio=0.25,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        K=4,
        temperature=34,
        init_weight=True,
    ):
        super().__init__()
        assert in_channels % groups == 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = ScaleAwareAttention2d(in_channels, ratio, K, temperature)
        self.weight = nn.Parameter(
            torch.randn(
                K, out_channels, in_channels // groups, kernel_size, kernel_size
            )
        )
        if bias:
            self.bias = nn.Parameter(torch.zeros(K, out_channels))
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def update_temperature(self):
        self.attention.update_temperature()

    def forward(self, x, scale):
        softmax_attention = self.attention(x, scale)  # (batch_size, K)
        batch_size, _, height, width = x.size()
        # Fold the batch into the channel dimension so that each sample is
        # convolved with its own aggregated kernel via the groups trick.
        x = x.view(1, -1, height, width)
        weight = self.weight.view(self.K, -1)
        # Per-sample weighted sum of the K candidate kernels.
        aggregate_weight = torch.mm(softmax_attention, weight).view(
            -1, self.in_channels // self.groups, self.kernel_size, self.kernel_size
        )
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
        else:
            aggregate_bias = None
        output = F.conv2d(
            x,
            weight=aggregate_weight,
            bias=aggregate_bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups * batch_size,
        )
        output = output.view(
            batch_size, self.out_channels, output.size(-2), output.size(-1)
        )
        return output
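
# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical, not part of the original module). It
# assumes `scale` is a plain scalar (e.g. a resize factor) that broadcasts
# against the two constant channels appended inside the attention branch, and
# that temperature annealing is driven externally, e.g. once per epoch.
if __name__ == "__main__":
    conv = ScaleAwareDynamicConv2d(
        in_channels=16, out_channels=32, kernel_size=3, padding=1
    )
    x = torch.randn(2, 16, 32, 32)  # (batch, channels, height, width)
    scale = 0.5  # hypothetical scale factor fed to the attention branch
    out = conv(x, scale)
    print(out.shape)  # torch.Size([2, 32, 32, 32])

    # During training, sharpen the kernel attention over time until the
    # temperature bottoms out at 1.
    conv.update_temperature()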