# Building_area/model/CBAM/reunet_cbam.py
import torch
import torch.nn as nn
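
# ResUNet with CBAM: a U-Net-style encoder-decoder built from pre-activation
# residual blocks, each refined by channel + spatial attention (CBAM), ending
# in a 1-channel sigmoid head for binary segmentation (building areas, per the
# repo path).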

class Spatial_attention(nn.Module):
    """Spatial attention: gates each location using pooled channel statistics."""

    def __init__(self, kernel=7):
        super().__init__()
        # padding = kernel // 2 keeps the attention map the same size as the
        # input for any odd kernel size
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel, padding=kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # channel-wise mean and max, each (B, 1, H, W)
        x1 = torch.mean(x, dim=1, keepdim=True)
        x2, _ = torch.max(x, dim=1, keepdim=True)
        f = torch.cat([x1, x2], dim=1)
        f_c = self.conv(f)
        f_s = self.sigmoid(f_c)
        return x * f_s

class Channel_attention(nn.Module):
    """Channel attention: shared MLP over average- and max-pooled descriptors."""

    def __init__(self, ch, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # bottleneck MLP shared by both pooled branches
        self.mlp = nn.Sequential(
            nn.Linear(ch, ch // ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch // ratio, ch, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # average-pooled branch: (B, C, 1, 1) -> (B, C)
        x1 = self.avg_pool(x).squeeze(-1).squeeze(-1)
        x1 = self.mlp(x1)
        # max-pooled branch
        x2 = self.max_pool(x).squeeze(-1).squeeze(-1)
        x2 = self.mlp(x2)
        # fuse both branches, gate, and rescale the input channels
        f = x1 + x2
        f_s = self.sigmoid(f).unsqueeze(-1).unsqueeze(-1)
        return x * f_s

class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention."""

    def __init__(self, ch):
        super().__init__()
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        x = self.channel(x)
        x = self.spatial(x)
        return x
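
# Usage sketch (illustrative, not part of the module's API): CBAM is
# shape-preserving, so it can be dropped into any block whose channel
# count is known, e.g.
#   attn = CBAM(64)
#   y = attn(torch.randn(2, 64, 32, 32))  # y.shape == (2, 64, 32, 32)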

class batch_norm(nn.Module):
    """Pre-activation unit: BatchNorm2d followed by ReLU."""

    def __init__(self, inp):
        super().__init__()
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.batch(x))


class residual(nn.Module):
    """Pre-activation residual block with CBAM on the main branch."""

    def __init__(self, inp, out, stride=1):
        super().__init__()
        self.bn1 = batch_norm(inp)
        self.conv1 = nn.Conv2d(inp, out, kernel_size=3, padding=1, stride=stride)
        self.bn2 = batch_norm(out)
        self.conv2 = nn.Conv2d(out, out, kernel_size=3, padding=1, stride=1)
        # 1x1 projection so the skip connection matches channels and stride
        self.concat = nn.Conv2d(inp, out, kernel_size=1, padding=0, stride=stride)
        # CBAM refines the main branch before the residual addition
        self.cbam = CBAM(out)

    def forward(self, input):
        x = self.bn1(input)
        x = self.conv1(x)
        x = self.bn2(x)
        x = self.conv2(x)
        x = self.cbam(x)
        skip = self.concat(input)
        return x + skip

class decoder(nn.Module):
    """Decoder stage: 2x bilinear upsampling, skip concatenation, residual block."""

    def __init__(self, inp, out):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # after concatenation the block sees inp (upsampled) + out (skip) channels
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)
        x = self.block(x)
        return x

class reunet_cbam(nn.Module):
    """Residual U-Net with CBAM attention in every encoder/decoder block."""

    def __init__(self):
        super().__init__()
        # encoder 1 (stride 1; written out because its first conv has no preceding BN)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(3, 64, kernel_size=1, padding=0, stride=1)  # skip projection
        self.cbam1 = CBAM(64)
        # encoders 2 and 3 each downsample by 2
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)
        # bridge
        self.bridge = residual(256, 512, stride=2)
        # decoders mirror the encoders
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)
        # single-channel sigmoid head for binary segmentation
        self.output = nn.Conv2d(64, 1, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # encoder 1
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.cbam1(x)
        shortcut = self.conv3(input)  # 1x1 projection of the raw input
        skip1 = x + shortcut
        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        # bridge
        b = self.bridge(skip3)
        # decoder
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)
        # output
        output = self.output(d3)
        output = self.sigmoid(output)
        return output
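

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file). Assumes a
    # 3-channel input whose height and width are divisible by 8, so the three
    # stride-2 stages downsample and the decoders upsample back cleanly.
    model = reunet_cbam()
    model.eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 1, 256, 256])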