import torch
import torch.nn as nn


class SpatialAttention(nn.Module):
    """Spatial attention: rescales each spatial location of the input
    using statistics pooled across the channel dimension (CBAM-style)."""

    def __init__(self, kernel_size=7):
        super().__init__()
        # padding = kernel_size // 2 keeps the spatial dimensions unchanged.
        # (The original hardcoded padding=3, which is only correct for kernel_size=7.)
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size,
                              padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool across channels: average and max, each of shape (N, 1, H, W).
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        # Stack the two maps and convolve down to a single-channel attention map.
        pooled = torch.cat([avg_pool, max_pool], dim=1)  # (N, 2, H, W)
        attention = self.sigmoid(self.conv(pooled))      # (N, 1, H, W)
        # Broadcast-multiply so every channel is weighted by the spatial mask.
        return x * attention
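
# A minimal usage sketch: the batch shape, channel count, and kernel size below
# are illustrative assumptions, not taken from the original. The check confirms
# the module returns a tensor of the same shape as its input.
if __name__ == "__main__":
    attn = SpatialAttention(kernel_size=7)
    x = torch.randn(4, 64, 32, 32)  # hypothetical input: N=4, C=64, 32x32 maps
    out = attn(x)
    print(out.shape)  # torch.Size([4, 64, 32, 32]); spatial attention preserves shape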