import torch
import torch.nn as nn


class Spatial_attention(nn.Module):
    """Spatial attention: reweights each location using channel-wise mean and max maps."""
    def __init__(self, kernel=7):
        super().__init__()
        # padding = kernel // 2 keeps the attention map at the input's spatial size (kernel must be odd)
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel, padding=kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = torch.mean(x, dim=1, keepdim=True)    # channel-wise average map
        x2, _ = torch.max(x, dim=1, keepdim=True)  # channel-wise max map
        f = torch.cat([x1, x2], dim=1)             # stack the two maps
        f_c = self.conv(f)
        f_s = self.sigmoid(f_c)                    # spatial attention weights in (0, 1)
        return x * f_s
    
class Channel_attention(nn.Module):
    """Channel attention: reweights channels using avg- and max-pooled descriptors passed through a shared MLP."""
    def __init__(self, ch, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # shared bottleneck MLP; ratio controls the hidden width (ch // ratio)
        self.mlp = nn.Sequential(
            nn.Linear(ch, ch // ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch // ratio, ch, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # average-pooled descriptor through the shared MLP
        x1 = self.avg_pool(x).squeeze(-1).squeeze(-1)
        x1 = self.mlp(x1)
        # max-pooled descriptor through the same MLP
        x2 = self.max_pool(x).squeeze(-1).squeeze(-1)
        x2 = self.mlp(x2)
        # sum the two descriptors and squash to per-channel weights in (0, 1)
        f = x1 + x2
        f_s = self.sigmoid(f).unsqueeze(-1).unsqueeze(-1)
        return x * f_s
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by spatial attention."""
    def __init__(self, ch):
        super().__init__()
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        x1 = self.channel(x)
        x2 = self.spatial(x1)
        return x2
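
# Usage sketch: CBAM preserves the input shape, so it can be dropped into any block
# without changing tensor sizes, e.g. for an assumed (2, 64, 32, 32) feature map:
#   refined = CBAM(64)(torch.randn(2, 64, 32, 32))   # shape stays (2, 64, 32, 32)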
    
class batch_norm(nn.Module):
    """BatchNorm2d followed by ReLU (the pre-activation used by the residual blocks)."""
    def __init__(self, inp):
        super().__init__()
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.batch(x))


class residual(nn.Module):
    """Pre-activation residual block with CBAM applied to the main branch before the skip addition."""
    def __init__(self, inp, out, stride=1):
        super().__init__()
        self.bn1 = batch_norm(inp)
        self.conv1 = nn.Conv2d(inp, out, kernel_size=3, padding=1, stride=stride)
        self.bn2 = batch_norm(out)
        self.conv2 = nn.Conv2d(out, out, kernel_size=3, padding=1, stride=1)
        # 1x1 projection so the skip path matches the main branch in channels and stride
        self.shortcut = nn.Conv2d(inp, out, kernel_size=1, padding=0, stride=stride)
        # CBAM refines the main branch before the residual addition
        self.cbam = CBAM(out)

    def forward(self, x):
        f = self.bn1(x)
        f = self.conv1(f)
        f = self.bn2(f)
        f = self.conv2(f)
        f = self.cbam(f)           # apply CBAM to the main branch
        skip = self.shortcut(x)    # projected identity path
        return f + skip
    
class decoder(nn.Module):
    """Decoder stage: upsample, concatenate the encoder skip, then apply a residual CBAM block."""
    def __init__(self, inp, out):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)   # concatenate along the channel axis
        x = self.block(x)
        return x
class reunet_cbam(nn.Module):
    """Residual U-Net with CBAM attention in every encoder/decoder block; outputs a 1-channel sigmoid mask."""
    def __init__(self):
        super().__init__()
        # encoder 1 (stem): two 3x3 convolutions, a 1x1 projected skip, and CBAM
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(3, 64, kernel_size=1, padding=0, stride=1)
        self.cbam1 = CBAM(64)
        # encoders 2 and 3 downsample with stride-2 residual blocks
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)
        # bridge
        self.bridge = residual(256, 512, stride=2)
        # decoders mirror the encoders
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)

        # output: 1-channel mask with sigmoid activation
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # encoder 1
        f = self.conv1(x)
        f = self.bn1(f)
        f = self.conv2(f)
        f = self.cbam1(f)              # apply CBAM
        shortcut = self.conv3(x)       # 1x1 projection of the input
        skip1 = f + shortcut
        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        # bridge
        b = self.bridge(skip3)
        # decoder
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)
        # output
        output = self.output(d3)
        output = self.sigmoid(output)
        return output
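

if __name__ == "__main__":
    # Minimal smoke test (a sketch): assumes square inputs whose side is divisible by 8,
    # so the three stride-2 stages can be bilinearly upsampled back to the input size.
    model = reunet_cbam()
    dummy = torch.randn(1, 3, 256, 256)   # (batch, RGB, height, width)
    mask = model(dummy)
    print(mask.shape)   # expected: torch.Size([1, 1, 256, 256])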