Pavan2k4's picture
app
35d85a5
raw
history blame contribute delete
389 Bytes
import torch
import torch.nn as nn
from channel_att import Channel_attention
from spatial_att import Spatial_attention
class CBAM(nn.Module):
    """Refine a feature map with channel attention followed by spatial attention.

    The input goes through a ``Channel_attention`` stage. That stage's output
    is fed to a ``Spatial_attention`` stage, and the spatial stage's output is
    returned unchanged. This class performs no multiplication or residual
    addition itself.

    NOTE(review): whether each stage returns an attention-weighted feature
    map or a bare attention mask is decided in ``channel_att`` /
    ``spatial_att``. Confirm they return refined features, otherwise the
    spatial stage receives a mask rather than features.

    Args:
        ch: Forwarded unchanged to ``Channel_attention`` (presumably the
            number of input channels; verify against that module).
    """

    def __init__(self, ch):
        super(CBAM, self).__init__()
        # nn.Module registers submodules in assignment order, which determines
        # parameter / state_dict ordering. Keep channel before spatial, and keep
        # these attribute names, so existing checkpoints still load.
        self.channel = Channel_attention(ch)
        self.spatial = Spatial_attention()

    def forward(self, x):
        """Apply the channel stage to ``x``, then the spatial stage to its output."""
        return self.spatial(self.channel(x))