import torch
import torch.nn as nn


class Spatial_attention(nn.Module):
    def __init__(self, kernel=7):
        super().__init__()
        # padding = kernel // 2 keeps the attention map the same size as the input
        # for any odd kernel (the original hard-coded padding=3 only matched kernel=7)
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel, padding=kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # channel-wise average and max descriptors, each of shape (B, 1, H, W)
        x1 = torch.mean(x, dim=1, keepdim=True)
        x2, _ = torch.max(x, dim=1, keepdim=True)
        f = torch.cat([x1, x2], dim=1)
        f_s = self.sigmoid(self.conv(f))
        return x * f_s


class Channel_attention(nn.Module):
    def __init__(self, ch, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # shared MLP applied to both pooled descriptors
        self.mlp = nn.Sequential(
            nn.Linear(ch, ch // ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch // ratio, ch, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # average-pooled branch, (B, C)
        x1 = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        # max-pooled branch, (B, C)
        x2 = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))
        # sum the two branches and gate the input channels
        f_s = self.sigmoid(x1 + x2).unsqueeze(-1).unsqueeze(-1)
        return x * f_s


class CBAM(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        # channel attention first, then spatial attention
        x = self.channel(x)
        x = self.spatial(x)
        return x


class batch_norm(nn.Module):
    """BatchNorm2d followed by ReLU."""

    def __init__(self, inp):
        super().__init__()
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.batch(x))


class residual(nn.Module):
    def __init__(self, inp, out, stride=1):
        super().__init__()
        self.bn1 = batch_norm(inp)
        self.conv1 = nn.Conv2d(inp, out, kernel_size=3, padding=1, stride=stride)
        self.bn2 = batch_norm(out)
        self.conv2 = nn.Conv2d(out, out, kernel_size=3, padding=1, stride=1)
        # skip connection: 1x1 convolution to match channels and stride
        self.concat = nn.Conv2d(inp, out, kernel_size=1, padding=0, stride=stride)
        # CBAM applied to the main branch before the addition
        self.cbam = CBAM(out)

    def forward(self, input):
        x = self.bn1(input)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.conv2(x)
        x = self.cbam(x)
        skip = self.concat(input)
        return x + skip


class decoder(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # the residual block receives the upsampled features concatenated with the skip
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)
        return self.block(x)


class reunet_cbam(nn.Module):
    def __init__(self):
        super().__init__()
        # encoder 1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(3, 64, kernel_size=1, padding=0, stride=1)
        self.cbam1 = CBAM(64)  # CBAM for encoder 1

        # encoders 2 and 3
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)

        # bridge
        self.bridge = residual(256, 512, stride=2)

        # decoder
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)

        # output
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # encoder 1 (the local variable is named `shortcut` to avoid
        # shadowing the `residual` class)
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.cbam1(x)
        shortcut = self.conv3(input)
        skip1 = x + shortcut

        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)

        # bridge
        b = self.bridge(skip3)

        # decoder
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)

        # output
        output = self.output(d3)
        return self.sigmoid(output)
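

# Minimal shape-check sketch (not part of the original file): it assumes a
# 3-channel input whose height and width are divisible by 8, so the three
# stride-2 encoder/bridge stages and the bilinear upsampling stages line up
# exactly with the skip connections. The input size 256x256 is illustrative.
if __name__ == "__main__":
    model = reunet_cbam()
    dummy = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        mask = model(dummy)
    # the network produces a single-channel sigmoid mask at input resolution
    print(mask.shape)  # expected: torch.Size([1, 1, 256, 256])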