import torch
import torch.nn as nn

class Spatial_attention(nn.Module):
    def __init__(self, kernel=7):
        super().__init__()
        # 2-channel input: channel-wise mean and max maps; padding keeps the spatial size
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel, padding=kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = torch.mean(x, dim=1, keepdim=True)
        x2, _ = torch.max(x, dim=1, keepdim=True)
        f = torch.cat([x1, x2], dim=1)
        f_c = self.conv(f)
        f_s = self.sigmoid(f_c)
        f_final = x * f_s
        return f_final

class Channel_attention(nn.Module):
    def __init__(self, ch, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # shared MLP with a ch // ratio bottleneck
        self.mlp = nn.Sequential(
            nn.Linear(ch, ch // ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch // ratio, ch, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # average-pooled channel descriptor
        x1 = self.avg_pool(x).squeeze(-1).squeeze(-1)
        x1 = self.mlp(x1)
        # max-pooled channel descriptor
        x2 = self.max_pool(x).squeeze(-1).squeeze(-1)
        x2 = self.mlp(x2)
        # sum, squash, and rescale the input channel-wise
        f = x1 + x2
        f_s = self.sigmoid(f).unsqueeze(-1).unsqueeze(-1)
        f_final = x * f_s
        return f_final

class CBAM(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        x1 = self.channel(x)
        x2 = self.spatial(x1)
        return x2
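
# Example (assumed usage, not from the original file): CBAM preserves the input shape,
# so it can be dropped in after any convolutional feature map, e.g.
#   attn = CBAM(64)
#   y = attn(torch.randn(1, 64, 32, 32))   # y.shape == (1, 64, 32, 32)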

class residual(nn.Module):
    def __init__(self, inp, out, stride=1):
        super().__init__()
        # pre-activation residual block: BN -> ReLU -> Conv, twice
        self.bn1 = batch_norm(inp)
        self.conv1 = nn.Conv2d(inp, out, kernel_size=3, padding=1, stride=stride)
        self.bn2 = batch_norm(out)
        self.conv2 = nn.Conv2d(out, out, kernel_size=3, padding=1, stride=1)
        # 1x1 convolution so the skip connection matches the main branch's channels and stride
        self.concat = nn.Conv2d(inp, out, kernel_size=1, padding=0, stride=stride)
        # CBAM applied to the main branch before the addition
        self.cbam = CBAM(out)

    def forward(self, input):
        x = self.bn1(input)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.conv2(x)
        x = self.cbam(x)  # apply CBAM
        skip = self.concat(input)
        return x + skip

class batch_norm(nn.Module):
    def __init__(self, inp):
        super().__init__()
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x):
        b = self.batch(x)
        op = self.relu(b)
        return op

class decoder(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        # upsample by 2, concatenate the encoder skip, then a residual block
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)
        x = self.block(x)
        return x

class reunet_cbam(nn.Module):
    def __init__(self):
        super().__init__()
        # encoder 1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(3, 64, kernel_size=1, padding=0, stride=1)
        self.cbam1 = CBAM(64)  # CBAM for encoder 1
        # encoder 2
        self.enc2 = residual(64, 128, stride=2)
        # encoder 3
        self.enc3 = residual(128, 256, stride=2)
        # bridge
        self.bridge = residual(256, 512, stride=2)
        # decoder
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)
        # output
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # encoder 1
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.cbam1(x)  # apply CBAM
        shortcut = self.conv3(input)  # 1x1 projection of the input
        skip1 = x + shortcut
        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        # bridge
        b = self.bridge(skip3)
        # decoder
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)
        # output
        output = self.output(d3)
        output = self.sigmoid(output)
        return output