# Uploaded by KyanChen ("Upload 1861 files", commit 3b96cb1).
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize
from opencd.registry import MODELS
class BAM(nn.Module):
    """Basic self-attention module.

    The input is average-pooled by ``ds``, dot-product self-attention is
    computed between all pooled positions, and the attended features are
    upsampled back to the input resolution and added to the input.

    Args:
        in_dim (int): Number of channels of the input feature map. The
            query/key projections use ``in_dim // 8`` channels.
        ds (int): Kernel size passed to ``nn.AvgPool2d`` before attention.
            Default: 8.
        activation: Stored on the module; not used by ``forward``.
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim
        self.key_channel = self.chanel_in // 8
        self.activation = activation
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.query_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(
            in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # NOTE: ``gamma`` is not used in ``forward``; it is kept so the set
        # of registered parameters (and therefore checkpoints) is unchanged.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """Apply pooled self-attention with a residual connection.

        Args:
            input (Tensor): 4-D feature map ``(B, C, dim2, dim3)``.

        Returns:
            Tensor: Attended features plus ``input``; same shape as ``input``.
        """
        x = self.pool(input)
        m_batchsize, C, width, height = x.size()
        num_pos = width * height

        # (B, N, C_key) and (B, C_key, N), N = number of pooled positions.
        proj_query = self.query_conv(x).view(m_batchsize, -1, num_pos)
        proj_query = proj_query.permute(0, 2, 1)
        proj_key = self.key_conv(x).view(m_batchsize, -1, num_pos)

        # Scaled dot-product similarity, normalised over the key axis.
        energy = torch.bmm(proj_query, proj_key)
        energy = (self.key_channel ** -.5) * energy
        attention = self.softmax(energy)  # (B, N, N)

        proj_value = self.value_conv(x).view(m_batchsize, -1, num_pos)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        # Upsample to the *input* resolution rather than pooled_size * ds:
        # the two agree when the input size is a multiple of ``ds``, and
        # otherwise only the input size keeps the residual sum valid.
        out = F.interpolate(out, size=tuple(input.shape[2:]))
        out = out + input
        return out
class _PAMBlock(nn.Module):
    """Basic self-attention (non-local) block for a pair of feature maps.

    Input/Output:
        ``N x C x H x (2*W)`` -- two feature maps concatenated along the
        last axis. The output has ``value_channels`` channels and the same
        spatial size as the input.

    Each half is partitioned into ``scale x scale`` regions; attention is
    computed inside every region jointly over both halves.

    NOTE: the regions are concatenated along the batch axis, so they must
    all have the same size, i.e. the (pooled) ``H`` and ``W`` need to be
    divisible by ``scale``.

    Args:
        in_channels (int): Channels of the input feature map.
        key_channels (int): Channels after the key/query transform.
        value_channels (int): Channels after the value transform.
        scale (int): Number of partitions per spatial axis. Default: 1.
        ds (int): Average-pooling factor applied before attention.
            Default: 1 (no pooling).
    """

    def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1):
        super(_PAMBlock, self).__init__()
        self.scale = scale
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.value_channels = value_channels

        self.f_key = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_query = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.key_channels)
        )
        self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels,
                                 kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _split_halves(feat, w):
        """Stack the left/right halves of ``feat`` on a new last axis (B, C, H, W, 2)."""
        return torch.stack([feat[:, :, :, :w], feat[:, :, :, w:]], 4)

    @staticmethod
    def _gather_regions(feat, regions):
        """Cut ``feat`` into the given regions and concatenate them along the batch axis."""
        return torch.cat(
            [feat[:, :, x0:x1, y0:y1] for x0, x1, y0, y1 in regions], dim=0)

    def _attend(self, value_local, query_local, key_local):
        """Scaled dot-product attention over every position of each region."""
        batch_size_new = value_local.size(0)
        h_local, w_local = value_local.size(2), value_local.size(3)

        value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1)
        query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1)
        query_local = query_local.permute(0, 2, 1)
        key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1)

        sim_map = torch.bmm(query_local, key_local)
        sim_map = (self.key_channels ** -.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        context_local = torch.bmm(value_local, sim_map.permute(0, 2, 1))
        return context_local.view(
            batch_size_new, self.value_channels, h_local, w_local, 2)

    def forward(self, input):
        """Run region-wise self-attention.

        Args:
            input (Tensor): Tensor of shape ``(B, C, H, 2*W)``.

        Returns:
            Tensor: Context tensor with ``value_channels`` channels and the
            spatial size of ``input``.
        """
        x = input
        if self.ds != 1:
            x = self.pool(input)

        # x has shape (b, c, h, 2w); ``w`` is the width of a single half.
        batch_size, h, w = x.size(0), x.size(2), x.size(3) // 2

        # Region bounds (row_start, row_end, col_start, col_end), row-major.
        step_h, step_w = h // self.scale, w // self.scale
        last = self.scale - 1
        regions = []
        for i in range(self.scale):
            for j in range(self.scale):
                start_x, start_y = i * step_h, j * step_w
                # The last row/column absorbs any remainder.
                end_x = h if i == last else min(start_x + step_h, h)
                end_y = w if j == last else min(start_y + step_w, w)
                regions.append((start_x, end_x, start_y, end_y))

        value = self._split_halves(self.f_value(x), w)
        query = self._split_halves(self.f_query(x), w)
        key = self._split_halves(self.f_key(x), w)

        # All regions are processed in one batched call to speed things up.
        v_locals = self._gather_regions(value, regions)
        q_locals = self._gather_regions(query, regions)
        k_locals = self._gather_regions(key, regions)
        context_locals = self._attend(v_locals, q_locals, k_locals)

        # Undo the batching: rebuild each row along W, then stack rows along H.
        rows = []
        for i in range(self.scale):
            row = []
            for j in range(self.scale):
                left = batch_size * (j + i * self.scale)
                row.append(context_locals[left:left + batch_size])
            rows.append(torch.cat(row, 3))
        context = torch.cat(rows, 2)

        # Put the two halves back side by side: (B, C, H, 2W).
        context = torch.cat([context[:, :, :, :, 0], context[:, :, :, :, 1]], 3)

        if self.ds != 1:
            # Upsample to the input size (equal to pooled_size * ds whenever
            # the input is a multiple of ``ds``) so output and input line up.
            context = F.interpolate(context, size=tuple(input.shape[2:]))
        return context
class PAMBlock(_PAMBlock):
    """``_PAMBlock`` with default channel sizes derived from ``in_channels``.

    Args:
        in_channels (int): Channels of the input feature map.
        key_channels (int | None): Key/query channels. Defaults to
            ``in_channels // 8`` when ``None``.
        value_channels (int | None): Value channels. Defaults to
            ``in_channels`` when ``None``.
        scale (int): Number of partitions per spatial axis. Default: 1.
        ds (int): Downsampling factor. Default: 1.
    """

    def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1):
        if key_channels is None:
            key_channels = in_channels // 8
        if value_channels is None:
            value_channels = in_channels
        super(PAMBlock, self).__init__(in_channels, key_channels, value_channels, scale, ds)
class PAM(nn.Module):
    """PAM module.

    Runs one ``PAMBlock`` per entry of ``sizes`` on the same input,
    concatenates the results along the channel axis and fuses them with a
    bias-free 1x1 convolution.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Channels of the output; also used as the value
            channels of every stage (key channels are ``out_channels // 8``).
        sizes (Sequence[int]): Partition scale of each stage. Default: (1,).
        ds (int): Downsampling factor forwarded to every stage. Default: 1.
    """

    def __init__(self, in_channels, out_channels, sizes=(1,), ds=1):
        super(PAM, self).__init__()
        self.group = len(sizes)
        self.ds = ds  # output stride
        self.value_channels = out_channels
        self.key_channels = out_channels // 8

        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds)
             for size in sizes])
        # Every stage emits ``value_channels`` channels, so the fusion conv
        # must take ``value_channels * group`` inputs. (This equals
        # ``in_channels * group`` only when in_channels == out_channels.)
        # The attribute keeps its historical name although it holds no BN.
        self.conv_bn = nn.Sequential(
            nn.Conv2d(self.value_channels * self.group, out_channels,
                      kernel_size=1, padding=0, bias=False),
        )

    def _make_stage(self, in_channels, key_channels, value_channels, size, ds):
        """Build the attention block for one partition scale."""
        return PAMBlock(in_channels, key_channels, value_channels, size, ds)

    def forward(self, feats):
        """Fuse the outputs of all stages.

        Args:
            feats (Tensor): Input passed unchanged to every stage.

        Returns:
            Tensor: Output of the 1x1 fusion convolution.
        """
        priors = [stage(feats) for stage in self.stages]
        return self.conv_bn(torch.cat(priors, 1))
def weights_init(m):
    """Initialise convolution and batch-norm layers in place.

    Intended to be used with ``nn.Module.apply``. Dispatch is by class name:

    * name contains ``'Conv'``: weight ~ N(0.0, 0.02)
    * name contains ``'BatchNorm'``: weight ~ N(1.0, 0.02), bias = 0

    Modules that match by name but have no ``weight``/``bias`` tensor (for
    example a batch-norm layer built with ``affine=False``) are left alone
    instead of raising.

    Args:
        m (nn.Module): Module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
class CDSA(nn.Module):
    """Self-attention module for change detection.

    The two input feature maps are concatenated along the last axis, passed
    through the selected attention module, and split back into two maps.

    Args:
        in_c (int): Channels of each input feature map.
        ds (int): Downsampling factor forwarded to the attention module.
            Default: 1.
        mode (str | None): ``'BAM'``, ``'PAM'``, or ``'None'``/``None`` for
            no attention (identity). Default: 'BAM'.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C,
                                sizes=[1, 2, 4, 8], ds=self.ds)
        elif self.mode == 'None' or self.mode is None:
            self.Self_Att = nn.Identity()
        else:
            # Fail at construction time instead of leaving ``Self_Att``
            # undefined and crashing later inside ``forward``.
            raise ValueError(
                f"Unsupported mode {mode!r}; expected 'BAM', 'PAM' or 'None'.")
        self.apply(weights_init)

    def forward(self, x1, x2):
        """Attend over the concatenated pair and split the result.

        Args:
            x1 (Tensor): First feature map, 4-D.
            x2 (Tensor): Second feature map, concatenated with ``x1`` on dim 3.

        Returns:
            tuple[Tensor, Tensor]: The attended tensor split along dim 3 at
            ``x1.shape[3]``.
        """
        split_at = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:split_at], x[:, :, :, split_at:]
@MODELS.register_module()
class STAHead(BaseDecodeHead):
    """The Head of STANet.

    Features of the two images are fused by a small FPN-style neck, refined
    by a self-attention module (:class:`CDSA`) and compared with a pairwise
    distance. The head outputs a single-channel distance map.

    Args:
        sa_mode (str): Attention variant passed to ``CDSA`` as ``mode``.
            Default: 'PAM'.
        sa_in_channels (int): Channels of the features fed to the attention
            module. Default: 256.
        sa_ds (int): Downsampling factor passed to ``CDSA``. Default: 1.
        distance_threshold (float): Distances strictly above this value are
            mapped to the positive logit in ``predict_by_feat``. Default: 1.
        **kwargs: Forwarded to ``BaseDecodeHead``. ``input_transform`` and
            ``num_classes`` are fixed by this head and must not be given.
    """

    def __init__(
            self,
            sa_mode='PAM',
            sa_in_channels=256,
            sa_ds=1,
            distance_threshold=1,
            **kwargs):
        super().__init__(input_transform='multiple_select', num_classes=1, **kwargs)

        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)
        self.distance_threshold = distance_threshold

        # One 1x1 conv per selected backbone level, mapping to ``channels``.
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels:
            fpn_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            self.fpn_convs.append(fpn_conv)

        # Fuses the concatenated levels into ``sa_in_channels`` channels.
        self.fpn_bottleneck = nn.Sequential(
            ConvModule(
                len(self.in_channels) * self.channels,
                sa_in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Dropout(0.5),
            ConvModule(
                sa_in_channels,
                sa_in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        )

        self.netA = CDSA(in_c=sa_in_channels, ds=sa_ds, mode=sa_mode)
        self.calc_dist = nn.PairwiseDistance(keepdim=True)
        # The head emits distances directly, so no classifier conv is used.
        self.conv_seg = nn.Identity()

    def base_forward(self, inputs):
        """Fuse the multi-level features of one image.

        Args:
            inputs (list[Tensor]): One feature map per entry of
                ``self.in_channels``.

        Returns:
            Tensor: Fused features at the resolution of ``inputs[0]``.
        """
        fpn_outs = [
            conv(feat) for conv, feat in zip(self.fpn_convs, inputs)
        ]

        # Bring every level to the resolution of the first one.
        target_size = fpn_outs[0].shape[2:]
        fpn_outs = [
            resize(
                out,
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for out in fpn_outs
        ]

        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats

    def forward(self, inputs):
        """Compute the distance map between the two images' features.

        Args:
            inputs (list[Tensor]): Backbone feature maps. Each selected map
                is split into two equal chunks along the channel axis, one
                per image.

        Returns:
            Tensor: Single-channel distance map at the spatial size of the
            first selected feature map.
        """
        # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        inputs1 = []
        inputs2 = []
        for feat in inputs:
            f1, f2 = torch.chunk(feat, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)

        f1 = self.base_forward(inputs1)
        f2 = self.base_forward(inputs2)
        f1, f2 = self.netA(f1, f2)

        # if you use PyTorch<=1.8, there may be some problems.
        # see https://github.com/justchenhao/STANet/issues/85
        # PairwiseDistance reduces over the last axis, so move channels last.
        f1 = f1.permute(0, 2, 3, 1)
        f2 = f2.permute(0, 2, 3, 1)
        dist = self.calc_dist(f1, f2).permute(0, 3, 1, 2)

        dist = F.interpolate(dist, size=inputs[0].shape[2:], mode='bilinear', align_corners=True)

        return dist

    def predict_by_feat(self, seg_logits, batch_img_metas):
        """Transform a batch of output seg_logits to the input shape.

        Distances above ``self.distance_threshold`` become ``100`` and the
        remaining ones ``-100`` before resizing. ``seg_logits`` is modified
        in place.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        # Build both masks from the untouched values before writing, which
        # makes a deep copy of the tensor unnecessary.
        above = seg_logits > self.distance_threshold
        not_above = seg_logits <= self.distance_threshold
        seg_logits[above] = 100
        seg_logits[not_above] = -100

        seg_logits = resize(
            input=seg_logits,
            size=batch_img_metas[0]['img_shape'],
            mode='bilinear',
            align_corners=self.align_corners)
        return seg_logits