Spaces:
Sleeping
Sleeping
File size: 795 Bytes
35d85a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import torch.nn as nn
class Channel_attention(nn.Module):
    """Channel attention gate (CBAM-style).

    Global average-pooled and max-pooled channel descriptors are each passed
    through one shared two-layer bottleneck MLP. Their element-wise sum is
    squashed with a sigmoid and used to rescale every channel of the input.

    Args:
        ch: Number of input channels (the size of dim -3 of the input).
        ratio: Reduction factor of the MLP bottleneck. The hidden width is
            ``max(1, ch // ratio)``, so a small ``ch`` never produces an
            empty (zero-width) hidden layer.
    """

    def __init__(self, ch, ratio=8):
        super().__init__()
        # Clamp to at least one hidden unit. With ch < ratio, ch // ratio is 0,
        # which would build a zero-width MLP whose output is identically 0,
        # pinning every channel weight to sigmoid(0) = 0.5 regardless of input.
        hidden = max(1, ch // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # A single MLP shared by both pooled branches (same weights used twice).
        self.mlp = nn.Sequential(
            nn.Linear(ch, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, ch, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale the channels of ``x`` by learned attention weights.

        Args:
            x: Tensor of shape (N, ch, H, W), or unbatched (ch, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Pool the spatial dims to 1x1, then drop them: (..., ch, 1, 1) -> (..., ch).
        avg_desc = self.mlp(self.avg_pool(x).squeeze(-1).squeeze(-1))
        max_desc = self.mlp(self.max_pool(x).squeeze(-1).squeeze(-1))
        # Element-wise sum of the two branches (not a concatenation).
        weights = self.sigmoid(avg_desc + max_desc)
        # Restore the spatial dims so the weights broadcast over H and W.
        return x * weights.unsqueeze(-1).unsqueeze(-1)