Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class Spatial_attention(nn.Module):
    """Spatial attention gate in the style of CBAM's spatial attention module.

    Pools the input across the channel axis (mean and max), stacks the two
    pooled maps, runs them through a single bias-free convolution followed by
    a sigmoid, and multiplies the input by the resulting per-pixel weights.

    Args:
        kernel: Convolution kernel size, either an int or an (h, w) pair.
            Every dimension must be a positive odd integer so that padding of
            ``k // 2`` keeps the attention map the same spatial size as the
            input. Defaults to 7 (padding 3).

    Raises:
        ValueError: If any kernel dimension is not a positive odd integer.
    """

    def __init__(self, kernel = 7):
        super().__init__()
        # Accept both an int and an (h, w) pair, as nn.Conv2d does.
        try:
            kernel_hw = tuple(kernel)
        except TypeError:
            kernel_hw = (kernel, kernel)
        if any(k < 1 or k % 2 == 0 for k in kernel_hw):
            raise ValueError(
                f"kernel must be a positive odd integer (or pair of them), got {kernel!r}"
            )
        # Padding used to be hard-coded to 3, which only preserves H and W for
        # a 7x7 kernel; any other size made forward() fail when multiplying
        # x by a differently-sized attention map. Derive it from the kernel.
        padding = tuple(k // 2 for k in kernel_hw)
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Reweight ``x`` by a learned spatial attention map.

        Args:
            x: Input tensor; channels are reduced along dim 1, so this
                presumably expects batched (N, C, H, W) input.

        Returns:
            ``x`` multiplied elementwise by a sigmoid-activated (N, 1, H, W)
            map that broadcasts across channels; same shape as ``x``.
        """
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        # Two pooled descriptors -> the conv's 2 input channels.
        pooled = torch.cat([avg_map, max_map], dim=1)
        weights = self.sigmoid(self.conv(pooled))
        return x * weights