|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import pdb |
|
|
|
class attention1d(nn.Module):
    """Attention over K candidate kernels for 1-D inputs.

    The input is average-pooled to length 1 and passed through a 1x1-conv
    bottleneck (fc1 -> ReLU -> fc2). A softmax over the K outputs, scaled by
    ``temperature``, is returned.

    Args:
        in_planes: number of input channels.
        ratios: bottleneck ratio; hidden width is ``int(in_planes * ratios) + 1``,
            except when ``in_planes == 3``, where the hidden width is ``K``.
        K: number of candidate kernels (length of the attention vector).
        temperature: initial softmax temperature; must satisfy
            ``temperature % 3 == 1`` so that repeated decrements of 3 reach 1.
        init_weight: if True, apply ``_initialize_weights`` after construction.
    """

    def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
        super().__init__()
        assert temperature % 3 == 1
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # 3-channel inputs use K hidden units instead of the ratio rule.
        hidden_planes = K if in_planes == 3 else int(in_planes * ratios) + 1
        self.fc1 = nn.Conv1d(in_planes, hidden_planes, 1, bias=False)
        self.fc2 = nn.Conv1d(hidden_planes, K, 1, bias=True)
        self.temperature = temperature
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) conv weights; zero conv biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                # This class itself creates no BatchNorm2d layer.
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def updata_temperature(self):
        """Lower the temperature by 3 and print it; no-op once it equals 1.

        The (misspelled) name is the one other classes in this file call.
        """
        if self.temperature == 1:
            return
        self.temperature -= 3
        print('Change temperature to:', str(self.temperature))

    def forward(self, x):
        """Return softmax attention of shape (batch, K).

        Assumes ``x`` is (batch, in_planes, length), as Dynamic_conv1d passes.
        """
        pooled = self.avgpool(x)
        hidden = F.relu(self.fc1(pooled))
        logits = self.fc2(hidden).view(hidden.size(0), -1)
        return F.softmax(logits / self.temperature, 1)
|
|
|
|
|
class Dynamic_conv1d(nn.Module):
    """1-D dynamic convolution: a per-sample mixture of K candidate kernels.

    ``attention1d`` produces K softmax weights per sample. The K kernels (and
    biases) are blended with those weights, and the per-sample kernels are
    applied in a single grouped convolution over the whole batch.

    Args:
        in_planes, out_planes, kernel_size, stride, padding, dilation, groups:
            as for ``nn.Conv1d`` (``kernel_size`` must be an int).
        ratio: bottleneck ratio forwarded to ``attention1d``.
        bias: if True, learn K bias vectors, initialised to zero.
        K: number of candidate kernels.
        temperature: initial attention temperature (must be 1 mod 3).
        init_weight: if True, Kaiming-uniform initialise every candidate kernel.
    """

    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34, init_weight=True):
        super(Dynamic_conv1d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = attention1d(in_planes, ratio, K, temperature)

        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes // groups, kernel_size), requires_grad=True)
        # torch.Tensor(...) would leave the bias as uninitialised memory
        # (arbitrary values, possibly NaN/inf); start from zeros instead.
        if bias:
            self.bias = nn.Parameter(torch.zeros(K, out_planes))
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform initialise each of the K candidate kernels."""
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def update_temperature(self):
        """Anneal the attention temperature (delegates to attention1d)."""
        self.attention.updata_temperature()

    def forward(self, x):
        """Convolve x with its own attention-weighted kernel.

        Args:
            x: tensor of shape (batch, in_planes, length).

        Returns:
            Tensor of shape (batch, out_planes, out_length).
        """
        softmax_attention = self.attention(x)  # (batch, K)
        batch_size, _, length = x.size()
        # Fold the batch into the channel axis so that one grouped conv applies
        # a different kernel to every sample.
        x = x.reshape(1, -1, length)
        weight = self.weight.view(self.K, -1)
        # Each conv group sees in_planes // groups input channels, so the
        # blended kernel must have that many input channels (not in_planes).
        aggregate_weight = torch.mm(softmax_attention, weight).view(
            -1, self.in_planes // self.groups, self.kernel_size)
        aggregate_bias = None
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
        output = F.conv1d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride,
                          padding=self.padding, dilation=self.dilation,
                          groups=self.groups * batch_size)
        return output.view(batch_size, self.out_planes, output.size(-1))
|
|
|
|
|
|
|
class attention2d(nn.Module):
    """Attention over K candidate kernels for 2-D (image-like) inputs.

    The input is pooled to 1x1 and sent through a 1x1-conv bottleneck
    (fc1 -> ReLU -> fc2). The K logits are divided by ``temperature`` and
    soft-maxed.

    Args:
        in_planes: number of input channels.
        ratios: bottleneck ratio; hidden width is ``int(in_planes * ratios) + 1``,
            or ``K`` when ``in_planes == 3``.
        K: number of candidate kernels.
        temperature: initial softmax temperature, required to be 1 mod 3.
        init_weight: if True, run ``_initialize_weights`` after building layers.
    """

    def __init__(self, in_planes, ratios, K, temperature, init_weight=True):
        super().__init__()
        assert temperature % 3 == 1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        if in_planes == 3:
            hidden_planes = K
        else:
            hidden_planes = int(in_planes * ratios) + 1
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.fc2 = nn.Conv2d(hidden_planes, K, 1, bias=True)
        self.temperature = temperature
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) conv weights, zero biases; BN to identity."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def updata_temperature(self):
        """Decrease the temperature by 3 and print it, unless it is already 1.

        The misspelled name is what other classes in this file call.
        """
        if self.temperature != 1:
            self.temperature = self.temperature - 3
            print('Change temperature to:', str(self.temperature))

    def forward(self, x):
        """Return softmax attention over the K kernels, shape (batch, K)."""
        squeezed = F.relu(self.fc1(self.avgpool(x)))
        scores = self.fc2(squeezed).view(squeezed.size(0), -1)
        return F.softmax(scores / self.temperature, 1)
|
|
|
|
|
class Dynamic_deepwise_conv2d(nn.Module):
    """2-D dynamic (grouped / depthwise) convolution with a separate attention input.

    Attention weights over K candidate kernels are computed from ``x``. The
    per-sample blended kernel is then applied to ``y``. Intended for depthwise
    use (``groups == in_planes``, one input channel per kernel), but any
    ``groups`` dividing ``in_planes`` is accepted.

    Args:
        in_planes, out_planes, kernel_size, stride, padding, dilation, groups:
            as for ``nn.Conv2d`` (``kernel_size`` must be an int).
        ratio: bottleneck ratio forwarded to ``attention2d``.
        bias: if True, learn K bias vectors, initialised to zero.
        K: number of candidate kernels.
        temperature: initial attention temperature (must be 1 mod 3).
        init_weight: if True, Kaiming-uniform initialise every candidate kernel.
    """

    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34, init_weight=True):
        super(Dynamic_deepwise_conv2d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = attention2d(in_planes, ratio, K, temperature)

        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True)
        # torch.Tensor(...) would leave the bias uninitialised (arbitrary
        # memory, possibly NaN/inf); start from zeros instead.
        if bias:
            self.bias = nn.Parameter(torch.zeros(K, out_planes))
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform initialise each of the K candidate kernels."""
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def update_temperature(self):
        """Anneal the attention temperature (delegates to attention2d)."""
        self.attention.updata_temperature()

    def forward(self, x, y):
        """Convolve y with kernels chosen by attention computed on x.

        Args:
            x: tensor (batch, in_planes, H, W); used only for the attention.
            y: tensor (batch, in_planes, H', W') that is convolved. Its spatial
                size may differ from x's.

        Returns:
            Tensor of shape (batch, out_planes, H_out, W_out).
        """
        softmax_attention = self.attention(x)  # (batch, K)
        batch_size, _, _, _ = x.size()
        # Fold the batch into channels using y's own spatial size; x's size is
        # irrelevant to the convolution.
        y = y.reshape(1, -1, *y.shape[-2:])
        weight = self.weight.view(self.K, -1)
        # in_planes // groups input channels per kernel (== 1 when depthwise).
        aggregate_weight = torch.mm(softmax_attention, weight).view(
            -1, self.in_planes // self.groups, self.kernel_size, self.kernel_size)
        aggregate_bias = None
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
        output = F.conv2d(y, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride,
                          padding=self.padding, dilation=self.dilation,
                          groups=self.groups * batch_size)
        return output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
|
|
|
class Dynamic_conv2d(nn.Module):
    """2-D dynamic convolution with attention computed from a separate input.

    Attention weights over K candidate kernels come from ``x``. The K kernels
    and biases are blended per sample, and the result is applied to ``y`` in a
    single grouped convolution over the batch.

    Args:
        in_planes, out_planes, kernel_size, stride, padding, dilation, groups:
            as for ``nn.Conv2d`` (``kernel_size`` must be an int).
        ratio: bottleneck ratio forwarded to ``attention2d``.
        bias: if True, learn K bias vectors, initialised to zero.
        K: number of candidate kernels.
        temperature: initial attention temperature (must be 1 mod 3).
        init_weight: if True, Kaiming-uniform initialise every candidate kernel.
    """

    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34, init_weight=True):
        super(Dynamic_conv2d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = attention2d(in_planes, ratio, K, temperature)

        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size), requires_grad=True)
        # torch.Tensor(...) would leave the bias uninitialised (arbitrary
        # memory, possibly NaN/inf); start from zeros instead.
        if bias:
            self.bias = nn.Parameter(torch.zeros(K, out_planes))
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform initialise each of the K candidate kernels."""
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def update_temperature(self):
        """Anneal the attention temperature (delegates to attention2d)."""
        self.attention.updata_temperature()

    def forward(self, x, y):
        """Convolve y with per-sample kernels selected by attention on x.

        Args:
            x: tensor (batch, in_planes, H, W); used only for the attention.
            y: tensor (batch, in_planes, H', W') that is convolved. Its spatial
                size may differ from x's.

        Returns:
            Tensor of shape (batch, out_planes, H_out, W_out).
        """
        softmax_attention = self.attention(x)  # (batch, K)
        batch_size, _, _, _ = x.size()
        # Fold the batch into channels using y's own spatial size so that one
        # grouped conv applies a different kernel to each sample.
        y = y.reshape(1, -1, *y.shape[-2:])
        weight = self.weight.view(self.K, -1)
        # Each conv group sees in_planes // groups input channels, so the
        # blended kernel must have that many input channels (not in_planes).
        aggregate_weight = torch.mm(softmax_attention, weight).view(
            -1, self.in_planes // self.groups, self.kernel_size, self.kernel_size)
        aggregate_bias = None
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
        output = F.conv2d(y, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride,
                          padding=self.padding, dilation=self.dilation,
                          groups=self.groups * batch_size)
        return output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
|
|
|
|
|
class attention3d(nn.Module):
    """Attention over K candidate kernels for 3-D (volumetric) inputs.

    Global average pooling to 1x1x1 is followed by a bias-free 1x1x1-conv
    bottleneck (fc1 -> ReLU -> fc2) and a temperature-scaled softmax over the
    K outputs. Layers keep PyTorch's default initialisation.

    Args:
        in_planes: number of input channels.
        ratios: bottleneck ratio; hidden width is ``int(in_planes * ratios) + 1``,
            or ``K`` when ``in_planes == 3``.
        K: number of candidate kernels.
        temperature: initial softmax temperature, required to be 1 mod 3.
    """

    def __init__(self, in_planes, ratios, K, temperature):
        super().__init__()
        assert temperature % 3 == 1
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        hidden_planes = K if in_planes == 3 else int(in_planes * ratios) + 1
        self.fc1 = nn.Conv3d(in_planes, hidden_planes, 1, bias=False)
        self.fc2 = nn.Conv3d(hidden_planes, K, 1, bias=False)
        self.temperature = temperature

    def updata_temperature(self):
        """Step the temperature down by 3 and print it; stays put at 1.

        The misspelled name is what other classes in this file call.
        """
        if self.temperature == 1:
            return
        self.temperature -= 3
        print('Change temperature to:', str(self.temperature))

    def forward(self, x):
        """Return softmax attention over the K kernels, shape (batch, K)."""
        features = self.fc1(self.avgpool(x))
        features = F.relu(features)
        logits = self.fc2(features).view(features.size(0), -1)
        return F.softmax(logits / self.temperature, 1)
|
|
|
class Dynamic_conv3d(nn.Module):
    """3-D dynamic convolution: a per-sample mixture of K candidate kernels.

    ``attention3d`` produces K softmax weights per sample. The K kernels and
    biases are blended accordingly and applied in one grouped convolution
    over the batch.

    Args:
        in_planes, out_planes, kernel_size, stride, padding, dilation, groups:
            as for ``nn.Conv3d`` (``kernel_size`` must be an int).
        ratio: bottleneck ratio forwarded to ``attention3d``.
        bias: if True, learn K bias vectors, initialised to zero.
        K: number of candidate kernels.
        temperature: initial attention temperature (must be 1 mod 3).
        init_weight: if True (default), Kaiming-uniform initialise every
            candidate kernel, matching the 1-D/2-D variants. If False, the
            kernels keep their raw ``torch.randn`` values.
    """

    def __init__(self, in_planes, out_planes, kernel_size, ratio=0.25, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4, temperature=34, init_weight=True):
        super(Dynamic_conv3d, self).__init__()
        assert in_planes % groups == 0
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.K = K
        self.attention = attention3d(in_planes, ratio, K, temperature)

        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes // groups, kernel_size, kernel_size, kernel_size), requires_grad=True)
        # torch.Tensor(...) would leave the bias uninitialised (arbitrary
        # memory, possibly NaN/inf); start from zeros instead.
        if bias:
            self.bias = nn.Parameter(torch.zeros(K, out_planes))
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-uniform initialise each of the K candidate kernels."""
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def update_temperature(self):
        """Anneal the attention temperature (delegates to attention3d)."""
        self.attention.updata_temperature()

    def forward(self, x):
        """Convolve x with its own attention-weighted kernel.

        Args:
            x: tensor of shape (batch, in_planes, D, H, W).

        Returns:
            Tensor of shape (batch, out_planes, D_out, H_out, W_out).
        """
        softmax_attention = self.attention(x)  # (batch, K)
        batch_size, _, depth, height, width = x.size()
        # Fold the batch into channels so one grouped conv applies a different
        # kernel to each sample.
        x = x.reshape(1, -1, depth, height, width)
        weight = self.weight.view(self.K, -1)
        # Each conv group sees in_planes // groups input channels.
        aggregate_weight = torch.mm(softmax_attention, weight).view(
            -1, self.in_planes // self.groups, self.kernel_size, self.kernel_size, self.kernel_size)
        aggregate_bias = None
        if self.bias is not None:
            aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
        output = F.conv3d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride,
                          padding=self.padding, dilation=self.dilation,
                          groups=self.groups * batch_size)
        return output.view(batch_size, self.out_planes, output.size(-3), output.size(-2), output.size(-1))
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a Dynamic_conv2d and print the output shape of one
    # forward pass. Use the GPU when available so the script also runs on
    # CPU-only machines.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    x = torch.randn(12, 256, 64, 64, device=device)
    y = torch.randn(12, 256, 64, 64, device=device)

    model = Dynamic_conv2d(in_planes=256, out_planes=256, kernel_size=3, ratio=0.25, padding=1, groups=1)
    model.to(device)

    # Inference only: no need to build the autograd graph.
    with torch.no_grad():
        print(model(x, y).shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|