File size: 554 Bytes
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch 
import torch.nn as nn
class Spatial_attention(nn.Module):
    """Spatial attention module (CBAM-style).

    Builds a 1-channel spatial attention map from the channel-wise mean and
    max of the input, passes it through a conv + sigmoid, and rescales the
    input by it.

    Args:
        kernel: Odd convolution kernel size for the attention map (default 7).
    """

    def __init__(self, kernel=7):
        super().__init__()
        # padding = kernel // 2 keeps spatial dims unchanged for any odd
        # kernel (the original hard-coded padding=3 only worked for kernel=7,
        # breaking the final x * f_s broadcast for other kernel sizes).
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel, padding=kernel // 2,
                              bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled by its spatial attention map.

        Args:
            x: Tensor of shape (N, C, H, W).

        Returns:
            Tensor of shape (N, C, H, W), elementwise ``x * sigmoid(conv(...))``.
        """
        x1 = torch.mean(x, dim=1, keepdim=True)      # (N, 1, H, W) channel mean
        x2, _ = torch.max(x, dim=1, keepdim=True)    # (N, 1, H, W) channel max
        f = torch.concat([x1, x2], dim=1)            # (N, 2, H, W)
        f_c = self.conv(f)                           # (N, 1, H, W)
        f_s = self.sigmoid(f_c)                      # attention in (0, 1)
        f_final = x * f_s                            # broadcast over channels

        return f_final