import torch
import torch.nn as nn

from channel_att import Channel_attention
from spatial_att import Spatial_attention


class CBAM(nn.Module):
    """Sequential channel-then-spatial attention block.

    Presumably an implementation of CBAM (Convolutional Block Attention
    Module, Woo et al. 2018) — confirm against the project's intent.

    The input is first passed through ``Channel_attention``. That stage's
    result is then passed through ``Spatial_attention``. Whatever the spatial
    stage returns is this module's output; whether each stage returns a
    refined feature map or a bare attention map is decided by those
    modules, not here.

    Args:
        ch: Value forwarded unchanged to ``Channel_attention``. Presumably the
            channel count of the incoming feature map (TODO confirm against
            ``channel_att.Channel_attention``).
    """

    def __init__(self, ch):
        super().__init__()
        # If these stages are nn.Module instances, PyTorch registers them
        # under these attribute names. The names then become the state_dict
        # key prefixes ("channel.*", "spatial.*"), so renaming them would
        # break loading of existing checkpoints.
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        """Run ``x`` through the channel stage, then the spatial stage."""
        return self.spatial(self.channel(x))