File size: 795 Bytes
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch 
import torch.nn  as nn
class Channel_attention(nn.Module):
    """Channel-attention module (CBAM-style).

    Produces a per-channel gate in (0, 1) by pooling the spatial dims with
    both average and max pooling, passing each pooled descriptor through a
    shared bottleneck MLP, summing the two results, and applying a sigmoid.
    The input is rescaled channel-wise by that gate.

    Args:
        ch: number of input channels.
        ratio: bottleneck reduction ratio of the shared MLP (default 8).
    """

    def __init__(self, ch, ratio=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared two-layer bottleneck: ch -> ch//ratio -> ch, no biases.
        self.mlp = nn.Sequential(
            nn.Linear(ch, ch // ratio, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch // ratio, ch, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled by its channel-attention gate.

        Expects ``x`` of shape (B, C, H, W); returns the same shape.
        """
        pooled_avg = self.avg_pool(x)
        pooled_max = self.max_pool(x)
        # Collapse the trailing 1x1 spatial dims before the shared MLP.
        avg_attn = self.mlp(pooled_avg.squeeze(-1).squeeze(-1))
        max_attn = self.mlp(pooled_max.squeeze(-1).squeeze(-1))
        # Sum (not concat) the two descriptors, squash to (0, 1),
        # then broadcast back over the spatial dims.
        gate = self.sigmoid(avg_attn + max_attn)
        return x * gate[..., None, None]