Building_area / model /CBAM /spatial_att.py
Pavan2k4's picture
app
35d85a5
raw
history blame contribute delete
554 Bytes
import torch
import torch.nn as nn
class Spatial_attention(nn.Module):
    """Spatial attention gate (CBAM-style, per this file's CBAM/ location).

    The input is pooled across the channel dimension (dim 1) with a mean and
    a max, producing two single-channel maps. These are stacked into a
    2-channel map and passed through a bias-free ``kernel x kernel``
    convolution and a sigmoid. The result is a one-channel gate, which
    multiplies the input, broadcast over all channels.

    Args:
        kernel: Side length of the square convolution. It must be a positive
            odd integer so that "same" padding (``kernel // 2``) keeps the
            gate at the input's spatial size. Defaults to 7.

    Raises:
        ValueError: If ``kernel`` is not a positive odd integer.
    """

    def __init__(self, kernel: int = 7):
        super().__init__()
        # Padding must follow the kernel size. The previous hard-coded
        # padding=3 only preserved H x W for kernel=7; any other size made
        # the gate a different size than ``x`` and broke the multiply in
        # forward().
        if kernel < 1 or kernel % 2 == 0:
            raise ValueError(
                f"kernel must be a positive odd integer, got {kernel!r}"
            )
        # 2 input channels: the channel-wise mean map and max map.
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel, padding=kernel // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by its spatial attention gate.

        Args:
            x: Tensor laid out as (N, C, H, W). Channels must be on dim 1,
                since both poolings reduce over dim 1 and the conv expects
                2 input channels.

        Returns:
            Tensor with the same shape as ``x``.
        """
        avg_map = torch.mean(x, dim=1, keepdim=True)      # (N, 1, H, W)
        max_map, _ = torch.max(x, dim=1, keepdim=True)    # (N, 1, H, W)
        pooled = torch.cat([avg_map, max_map], dim=1)     # (N, 2, H, W)
        gate = self.sigmoid(self.conv(pooled))            # (N, 1, H, W)
        # The single-channel gate broadcasts across all C channels of x.
        return x * gate