Spaces:
Sleeping
Sleeping
File size: 554 Bytes
35d85a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
import torch.nn as nn
class Spatial_attention(nn.Module):
    """Spatial attention gate (CBAM-style).

    Channel-wise mean and max maps are stacked into a 2-channel map. A single
    ``kernel x kernel`` convolution and a sigmoid turn this map into one
    attention weight per spatial location. The weights are then multiplied
    back onto the input.

    Args:
        kernel: Side length of the square convolution kernel. Must be a
            positive odd integer so that symmetric "same" padding
            (``kernel // 2``) preserves the spatial size of the input.

    Raises:
        ValueError: If ``kernel`` is not a positive odd integer.
    """

    def __init__(self, kernel=7):
        super().__init__()
        if kernel < 1 or kernel % 2 == 0:
            # An even kernel cannot keep H and W unchanged with symmetric
            # padding, and the final ``x * f_s`` needs matching sizes.
            raise ValueError(
                f"kernel must be a positive odd integer, got {kernel}"
            )
        # Padding is derived from the kernel (previously hard-coded to 3,
        # which only worked for kernel=7).
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel, padding=kernel // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` reweighted by a per-location attention map.

        Args:
            x: Input tensor, expected to be (N, C, H, W). The channel
                reductions use ``dim=1`` and the result feeds a 2-channel
                ``Conv2d``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Summarise the channels at each spatial position: average and peak.
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        stacked = torch.cat([avg_map, max_map], dim=1)  # (N, 2, H, W)
        attention = self.sigmoid(self.conv(stacked))  # (N, 1, H, W), in (0, 1)
        # Broadcast the single-channel attention over every input channel.
        return x * attention
|