File size: 389 Bytes
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch 
import torch.nn as nn
from channel_att import Channel_attention
from spatial_att import Spatial_attention
class CBAM(nn.Module):
    """CBAM-style attention block: channel attention, then spatial attention.

    The input goes through ``Channel_attention`` first. That stage's output
    goes through ``Spatial_attention``, whose result is returned unchanged.

    NOTE(review): this class never multiplies an attention map back onto the
    input itself. It therefore assumes each sub-module returns already-refined
    features rather than a bare attention map. Confirm against
    channel_att.py / spatial_att.py.
    """

    def __init__(self, ch):
        """Build both attention stages.

        Args:
            ch: Forwarded to ``Channel_attention``. Presumably the number of
                input channels; TODO confirm. ``Spatial_attention`` is built
                with no arguments.
        """
        super(CBAM, self).__init__()
        # Keep this registration order (channel, then spatial). It fixes the
        # order of parameters()/state_dict entries. It also fixes the order in
        # which sub-modules consume the RNG if they randomly initialise weights.
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        """Return ``spatial(channel(x))``; the channel stage always runs first."""
        return self.spatial(self.channel(x))