import torch
import torch.nn as nn


class Channel_attention(nn.Module):
    """Channel attention gate (CBAM-style).

    Global average- and max-pooled channel descriptors are each passed
    through a shared two-layer MLP. The two results are summed and squashed
    with a sigmoid, and the input is rescaled channel-wise by that gate.

    Args:
        ch: Number of input channels.
        ratio: Channel reduction factor for the MLP's hidden layer. The
            hidden width is ``max(1, ch // ratio)``, so small ``ch`` values
            still get a usable bottleneck.

    Raises:
        ValueError: If ``ratio`` is less than 1.
    """

    def __init__(self, ch, ratio = 8):
        super().__init__()
        if ratio < 1:
            raise ValueError(f"ratio must be >= 1, got {ratio}")
        # Without the clamp, ch < ratio would give a zero-width hidden layer.
        # The MLP would then always output 0 and the gate would collapse to
        # a constant sigmoid(0) == 0.5.
        hidden = max(1, ch // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # One MLP shared by both pooled descriptors; indices 0 and 2 hold the
        # weights, which keeps the original state_dict keys.
        self.mlp = nn.Sequential(
            nn.Linear(ch, hidden, bias = False),
            nn.ReLU(inplace = True),
            nn.Linear(hidden, ch, bias = False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled by the per-channel attention gate.

        Args:
            x: Tensor of shape ``(N, ch, H, W)``. An unbatched
                ``(ch, H, W)`` tensor also works, because only the two
                trailing pooled dims are squeezed and restored.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Pool to (..., ch, 1, 1), then drop the two spatial singleton dims
        # so the Linear layers see the channel axis last.
        avg_branch = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_branch = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))
        # The branches are merged by element-wise sum, not concatenation.
        gate = self.sigmoid(avg_branch + max_branch)
        # Restore the spatial singleton dims so the gate broadcasts over H, W.
        return x * gate.unsqueeze(-1).unsqueeze(-1)