# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize
from opencd.registry import MODELS


@MODELS.register_module()
class TinyHead(BaseDecodeHead):
    """
    This head is the implementation of `TinyCDv2
    <https://arxiv.org/abs/>`_.
    Args:
        feature_strides (tuple[int]): The strides for input feature maps.
            stack_lateral. All strides suppose to be power of 2. The first
            one is of largest resolution.
        priori_attn (bool): Whether use Priori Guiding Connection.
            Default to False.
    """

    def __init__(self, feature_strides, priori_attn=False, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        if priori_attn:
            # Reserve the first (largest-resolution) feature map for the
            # Priori Guiding Connection and drop it from the lateral inputs.
            attn_channels = self.in_channels[0]
            self.in_channels = self.in_channels[1:]
            feature_strides = feature_strides[1:]
        self.feature_strides = feature_strides
        self.priori_attn = priori_attn

        # One 1x1 ConvModule per lateral input, projecting it to
        # ``self.channels``.
        self.scale_heads = nn.ModuleList()
        for i in range(len(feature_strides)):
            scale_head = []
            scale_head.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    groups=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            self.scale_heads.append(nn.Sequential(*scale_head))

        if self.priori_attn:
            # Projects the absolute difference between the two temporal
            # halves of the early feature map to ``self.channels``.
            self.gen_diff_attn = ConvModule(
                in_channels=attn_channels // 2,
                out_channels=self.channels,
                kernel_size=1,
                stride=1,
                groups=1,
                norm_cfg=None,
                act_cfg=None)

    def forward(self, inputs):
        x = self._transform_inputs(inputs)

        if self.priori_attn:
            # The first (largest-resolution) map feeds the Priori Guiding
            # Connection; the remaining maps go through the scale heads.
            early_x = x[0]
            x = x[1:]

        # Fuse the laterals: project each map, upsample it to the resolution
        # of the first (largest) one, and sum. The addition is non-inplace to
        # keep autograd safe.
        output = self.scale_heads[0](x[0])
        for i in range(1, len(self.feature_strides)):
            output = output + resize(
                self.scale_heads[i](x[i]),
                size=output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        if self.priori_attn:
            # Split the channel-concatenated bitemporal feature and use the
            # absolute difference as a sigmoid-gated residual attention.
            x1_, x2_ = torch.chunk(early_x, 2, dim=1)
            diff_x = torch.abs(x1_ - x2_)
            diff_x = self.gen_diff_attn(diff_x)
            if diff_x.shape[2:] != output.shape[2:]:
                output = resize(
                    output,
                    size=diff_x.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
            output = output * torch.sigmoid(diff_x) + output

        output = self.cls_seg(output)
        return output
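

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the library. The strides,
    # channel counts, and input sizes below are illustrative assumptions,
    # not values from the TinyCDv2 paper; the remaining kwargs follow the
    # standard mmseg ``BaseDecodeHead`` API.
    head = TinyHead(
        feature_strides=(2, 4, 8, 16),
        in_channels=[16, 32, 64, 128],
        in_index=[0, 1, 2, 3],
        channels=32,
        num_classes=2,
        priori_attn=True,
        norm_cfg=dict(type='BN'),
        align_corners=False)
    # Bitemporal features are assumed to be channel-concatenated, so the
    # stride-2 map (16 channels) splits into two 8-channel halves inside
    # the Priori Guiding Connection.
    feats = [
        torch.rand(2, c, 256 // s, 256 // s)
        for c, s in zip([16, 32, 64, 128], (2, 4, 8, 16))
    ]
    logits = head(feats)
    # The fused map is resized to the stride-2 grid before classification.
    assert logits.shape == (2, 2, 128, 128)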