Spaces:
Runtime error
Runtime error
import copy | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
from mmseg.models.utils import resize | |
from opencd.registry import MODELS | |
class BAM(nn.Module):
    """Basic self-attention module (BAM) from STANet.

    The input is average-pooled by ``ds``, one scaled dot-product attention
    map is computed over all pooled positions, and the attended values are
    upsampled back to the input's spatial size and added to the input as a
    residual.

    Args:
        in_dim (int): Number of input (and output) channels. Query/key
            projections use ``in_dim // 8`` channels.
        ds (int): Kernel/stride of the average pool. Default: 8.
        activation: Stored on the module but not used by ``forward``.
            Default: ``nn.ReLU``.
    """

    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim  # (sic) attribute name kept for compatibility
        self.key_channel = self.chanel_in // 8
        self.activation = activation
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # NOTE(review): ``gamma`` is registered (so checkpoints keep the key)
        # but ``forward`` never scales the attention output by it.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """Apply pooled self-attention and add the result to ``input``.

        Args:
            input (Tensor): Feature map of shape (B, C, H, W), C == in_dim.

        Returns:
            Tensor: Attention output + ``input``, same shape as ``input``.
        """
        x = self.pool(input)
        batch, channels, height, width = x.size()
        n = height * width
        # (B, N, C/8) @ (B, C/8, N) -> (B, N, N) similarity over pooled pixels.
        proj_query = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch, -1, n)
        energy = torch.bmm(proj_query, proj_key)
        energy = (self.key_channel ** -.5) * energy
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(batch, -1, n)  # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch, channels, height, width)
        # Upsample to the exact input size. Using ``pooled_size * ds`` (as
        # before) differs from the input size whenever H or W is not
        # divisible by ``ds`` and made the residual addition fail.
        out = F.interpolate(out, size=input.shape[2:])
        out = out + input
        return out
class _PAMBlock(nn.Module): | |
''' | |
The basic implementation for self-attention block/non-local block | |
Input/Output: | |
N * C * H * (2*W) | |
Parameters: | |
in_channels : the dimension of the input feature map | |
key_channels : the dimension after the key/query transform | |
value_channels : the dimension after the value transform | |
scale : choose the scale to partition the input feature maps | |
ds : downsampling scale | |
''' | |
def __init__(self, in_channels, key_channels, value_channels, scale=1, ds=1): | |
super(_PAMBlock, self).__init__() | |
self.scale = scale | |
self.ds = ds | |
self.pool = nn.AvgPool2d(self.ds) | |
self.in_channels = in_channels | |
self.key_channels = key_channels | |
self.value_channels = value_channels | |
self.f_key = nn.Sequential( | |
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, | |
kernel_size=1, stride=1, padding=0), | |
nn.BatchNorm2d(self.key_channels) | |
) | |
self.f_query = nn.Sequential( | |
nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, | |
kernel_size=1, stride=1, padding=0), | |
nn.BatchNorm2d(self.key_channels) | |
) | |
self.f_value = nn.Conv2d(in_channels=self.in_channels, out_channels=self.value_channels, | |
kernel_size=1, stride=1, padding=0) | |
def forward(self, input): | |
x = input | |
if self.ds != 1: | |
x = self.pool(input) | |
# input shape: b,c,h,2w | |
batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3) // 2 | |
local_y = [] | |
local_x = [] | |
step_h, step_w = h // self.scale, w // self.scale | |
for i in range(0, self.scale): | |
for j in range(0, self.scale): | |
start_x, start_y = i * step_h, j * step_w | |
end_x, end_y = min(start_x + step_h, h), min(start_y + step_w, w) | |
if i == (self.scale - 1): | |
end_x = h | |
if j == (self.scale - 1): | |
end_y = w | |
local_x += [start_x, end_x] | |
local_y += [start_y, end_y] | |
value = self.f_value(x) | |
query = self.f_query(x) | |
key = self.f_key(x) | |
value = torch.stack([value[:, :, :, :w], value[:, :, :, w:]], 4) # B*N*H*W*2 | |
query = torch.stack([query[:, :, :, :w], query[:, :, :, w:]], 4) # B*N*H*W*2 | |
key = torch.stack([key[:, :, :, :w], key[:, :, :, w:]], 4) # B*N*H*W*2 | |
local_block_cnt = 2 * self.scale * self.scale | |
# self-attention func | |
def func(value_local, query_local, key_local): | |
batch_size_new = value_local.size(0) | |
h_local, w_local = value_local.size(2), value_local.size(3) | |
value_local = value_local.contiguous().view(batch_size_new, self.value_channels, -1) | |
query_local = query_local.contiguous().view(batch_size_new, self.key_channels, -1) | |
query_local = query_local.permute(0, 2, 1) | |
key_local = key_local.contiguous().view(batch_size_new, self.key_channels, -1) | |
sim_map = torch.bmm(query_local, key_local) # batch matrix multiplication | |
sim_map = (self.key_channels ** -.5) * sim_map | |
sim_map = F.softmax(sim_map, dim=-1) | |
context_local = torch.bmm(value_local, sim_map.permute(0, 2, 1)) | |
# context_local = context_local.permute(0, 2, 1).contiguous() | |
context_local = context_local.view(batch_size_new, self.value_channels, h_local, w_local, 2) | |
return context_local | |
# Parallel Computing to speed up | |
# reshape value_local, q, k | |
v_list = [value[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in | |
range(0, local_block_cnt, 2)] | |
v_locals = torch.cat(v_list, dim=0) | |
q_list = [query[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in | |
range(0, local_block_cnt, 2)] | |
q_locals = torch.cat(q_list, dim=0) | |
k_list = [key[:, :, local_x[i]:local_x[i + 1], local_y[i]:local_y[i + 1]] for i in range(0, local_block_cnt, 2)] | |
k_locals = torch.cat(k_list, dim=0) | |
context_locals = func(v_locals, q_locals, k_locals) | |
context_list = [] | |
for i in range(0, self.scale): | |
row_tmp = [] | |
for j in range(0, self.scale): | |
left = batch_size * (j + i * self.scale) | |
right = batch_size * (j + i * self.scale) + batch_size | |
tmp = context_locals[left:right] | |
row_tmp.append(tmp) | |
context_list.append(torch.cat(row_tmp, 3)) | |
context = torch.cat(context_list, 2) | |
context = torch.cat([context[:, :, :, :, 0], context[:, :, :, :, 1]], 3) | |
if self.ds != 1: | |
context = F.interpolate(context, [h * self.ds, 2 * w * self.ds]) | |
return context | |
class PAMBlock(_PAMBlock):
    """``_PAMBlock`` with default key/value channel widths.

    Args:
        in_channels (int): Channels of the input feature map.
        key_channels (int, optional): Query/key width. Defaults to
            ``in_channels // 8`` when ``None``.
        value_channels (int, optional): Value/output width. Defaults to
            ``in_channels`` when ``None``.
        scale (int): Grid size of the local attention partition. Default: 1.
        ds (int): Average-pool downsampling factor. Default: 1.
    """

    def __init__(self, in_channels, key_channels=None, value_channels=None, scale=1, ds=1):
        if key_channels is None:
            key_channels = in_channels // 8
        if value_channels is None:
            value_channels = in_channels
        super(PAMBlock, self).__init__(in_channels, key_channels, value_channels, scale, ds)
class PAM(nn.Module):
    """Pyramid attention module (PAM).

    Runs one ``PAMBlock`` per partition scale in ``sizes`` on the same input,
    concatenates their outputs along channels and fuses them with a bias-free
    1x1 convolution.

    Args:
        in_channels (int): Channels of the input feature map.
        out_channels (int): Value channels of every stage and channels of the
            fused output; key channels are ``out_channels // 8``.
        sizes (Sequence[int]): Partition scales, one stage each. Default: (1,).
        ds (int): Average-pool downsampling factor of every stage. Default: 1.
    """

    def __init__(self, in_channels, out_channels, sizes=(1,), ds=1):
        super(PAM, self).__init__()
        self.group = len(sizes)
        self.ds = ds  # output stride
        self.value_channels = out_channels
        self.key_channels = out_channels // 8
        self.stages = nn.ModuleList(
            [self._make_stage(in_channels, self.key_channels, self.value_channels, size, self.ds)
             for size in sizes])
        # Each stage emits ``value_channels`` channels, so the fused input has
        # ``value_channels * group`` channels. (``in_channels * group`` was
        # used before, which only worked when in_channels == out_channels.)
        self.conv_bn = nn.Sequential(
            nn.Conv2d(self.value_channels * self.group, out_channels,
                      kernel_size=1, padding=0, bias=False),
        )

    def _make_stage(self, in_channels, key_channels, value_channels, size, ds):
        """Build the attention stage for partition scale ``size``."""
        return PAMBlock(in_channels, key_channels, value_channels, size, ds)

    def forward(self, feats):
        """Apply every stage to ``feats`` and fuse the concatenated results."""
        priors = [stage(feats) for stage in self.stages]
        return self.conv_bn(torch.cat(priors, 1))
def weights_init(m):
    """Initialise one module; intended to be used via ``Module.apply``.

    Matching is by class name: modules whose class name contains ``'Conv'``
    get ``N(0, 0.02)`` weights; modules whose class name contains
    ``'BatchNorm'`` get ``N(1, 0.02)`` weights and zero bias. Biases of conv
    layers are left untouched. A matched module that has no weight/bias
    tensor (e.g. ``BatchNorm2d(affine=False)``, or a container module whose
    name happens to contain ``'Conv'``) is skipped instead of raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)
class CDSA(nn.Module):
    """Self-attention module for change detection.

    The two feature maps are concatenated along the width axis so a single
    attention module sees both, and the result is split back at the original
    width.

    Args:
        in_c (int): Channels of each input feature map.
        ds (int): Downsampling factor passed to the attention module. Default: 1.
        mode (str): ``'BAM'``, ``'PAM'`` (scales 1, 2, 4, 8) or ``'None'``
            (identity). Default: ``'BAM'``.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    def __init__(self, in_c, ds=1, mode='BAM'):
        super(CDSA, self).__init__()
        self.in_C = in_c
        self.ds = ds
        self.mode = mode
        if self.mode == 'BAM':
            self.Self_Att = BAM(self.in_C, ds=self.ds)
        elif self.mode == 'PAM':
            self.Self_Att = PAM(in_channels=self.in_C, out_channels=self.in_C,
                                sizes=[1, 2, 4, 8], ds=self.ds)
        elif self.mode == 'None':
            self.Self_Att = nn.Identity()
        else:
            # Fail at construction; previously ``Self_Att`` was never set and
            # ``forward`` died later with an AttributeError.
            raise ValueError(
                f"Unsupported self-attention mode {mode!r}; expected 'BAM', 'PAM' or 'None'.")
        self.apply(weights_init)

    def forward(self, x1, x2):
        """Attend jointly over ``x1`` and ``x2``.

        Args:
            x1, x2 (Tensor): (B, C, H, W) feature maps with equal shapes.

        Returns:
            tuple[Tensor, Tensor]: The refined ``x1`` and ``x2``.
        """
        width = x1.shape[3]
        x = torch.cat((x1, x2), 3)
        x = self.Self_Att(x)
        return x[:, :, :, 0:width], x[:, :, :, width:]
class STAHead(BaseDecodeHead):
    """The Head of STANet.

    Every selected backbone level is split along channels into two halves
    (presumably the features of the two input images stacked on the channel
    axis — confirm against the backbone). Both halves go through the same
    FPN-style fusion, are refined jointly by ``CDSA``, and the per-pixel
    Euclidean distance between them is returned as a one-channel map.

    Args:
        sa_mode (str): Self-attention mode of ``CDSA``: ``'BAM'``, ``'PAM'``
            or ``'None'``. Default: ``'PAM'``.
        sa_in_channels (int): Channels of the fused features fed to the
            self-attention module. Default: 256.
        sa_ds (int): Downsampling factor inside the self-attention module.
            Default: 1.
        distance_threshold (float): In ``predict_by_feat``, distances greater
            than this become logit 100, all others -100. Default: 1.
        **kwargs: Forwarded to ``BaseDecodeHead``. Must not contain
            ``input_transform`` or ``num_classes``; this head fixes them.
    """

    def __init__(
            self,
            sa_mode='PAM',
            sa_in_channels=256,
            sa_ds=1,
            distance_threshold=1,
            **kwargs):
        super().__init__(input_transform='multiple_select', num_classes=1, **kwargs)
        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)
        self.distance_threshold = distance_threshold
        # One 1x1 lateral conv per input level, projecting to ``self.channels``.
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels:
            fpn_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=False)
            self.fpn_convs.append(fpn_conv)
        # Fuse the concatenated levels down to the self-attention width.
        self.fpn_bottleneck = nn.Sequential(
            ConvModule(
                len(self.in_channels) * self.channels,
                sa_in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Dropout(0.5),
            ConvModule(
                sa_in_channels,
                sa_in_channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        )
        self.netA = CDSA(in_c=sa_in_channels, ds=sa_ds, mode=sa_mode)
        # Distance over the last dimension; ``forward`` moves channels last.
        self.calc_dist = nn.PairwiseDistance(keepdim=True)
        # No classifier: the distance map itself is the head's output.
        self.conv_seg = nn.Identity()

    def base_forward(self, inputs):
        """Fuse one image's multi-level features at the first level's resolution.

        Args:
            inputs (list[Tensor]): One feature map per entry of ``in_channels``.

        Returns:
            Tensor: Fused map with ``sa_in_channels`` channels.
        """
        fpn_outs = [
            self.fpn_convs[i](inputs[i])
            for i in range(len(self.in_channels))
        ]
        for i in range(len(self.in_channels)):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        feats = self.fpn_bottleneck(fpn_outs)
        return feats

    def forward(self, inputs):
        """Return the (B, 1, H0, W0) feature-distance map between the two halves."""
        # Receive 4 stage backbone feature map: 1/4, 1/8, 1/16, 1/32
        inputs = self._transform_inputs(inputs)
        inputs1 = []
        inputs2 = []
        for feat in inputs:
            f1, f2 = torch.chunk(feat, 2, dim=1)
            inputs1.append(f1)
            inputs2.append(f2)
        f1 = self.base_forward(inputs1)
        f2 = self.base_forward(inputs2)
        f1, f2 = self.netA(f1, f2)
        # if you use PyTorch<=1.8, there may be some problems.
        # see https://github.com/justchenhao/STANet/issues/85
        f1 = f1.permute(0, 2, 3, 1)
        f2 = f2.permute(0, 2, 3, 1)
        dist = self.calc_dist(f1, f2).permute(0, 3, 1, 2)
        dist = F.interpolate(dist, size=inputs[0].shape[2:], mode='bilinear', align_corners=True)
        return dist

    def predict_by_feat(self, seg_logits, batch_img_metas):
        """Threshold the distance map into +/-100 logits and resize it.

        Args:
            seg_logits (Tensor): Distance map from ``forward``. Not modified.
            batch_img_metas (list[dict]): Meta information of each image; the
                first entry's ``'img_shape'`` is used as the output size.

        Returns:
            Tensor: 100 where the distance exceeds ``distance_threshold``,
            -100 where it does not, resized to ``img_shape``.
        """
        # Write into a clone: the previous version overwrote the caller's
        # tensor in place, and ``copy.deepcopy`` fails on non-leaf tensors
        # that require grad. Masks come from the untouched input, so NaN
        # entries (neither > nor <=) stay NaN exactly as before.
        binarized = seg_logits.clone()
        binarized[seg_logits > self.distance_threshold] = 100
        binarized[seg_logits <= self.distance_threshold] = -100
        return resize(
            input=binarized,
            size=batch_img_metas[0]['img_shape'],
            mode='bilinear',
            align_corners=self.align_corners)