import torch
import torch.nn as nn
import torch.functional as F  # NOTE(review): unused below; `torch.nn.functional` was probably intended.
from bn import batch_norm
from residual import residual


class decoder(nn.Module):
    """Decoder stage: upsample, concatenate a skip connection, then refine.

    ``forward`` upsamples its input by a factor of 2 (bilinear,
    ``align_corners=True``), concatenates it with a skip tensor along the
    channel dimension (dim 1), and passes the result through a project-local
    ``residual`` block constructed as ``residual(inp + out, out)``.

    Args:
        inp: Presumably the channel count of ``x`` given to ``forward``.
            Only the sum ``inp + out`` is used here, as the channel count
            handed to ``residual``. TODO confirm against callers.
        out: Presumably the channel count of ``skip``. It is also passed as
            the second argument of ``residual`` (likely its output channels;
            verify against ``residual``).
    """

    def __init__(self, inp, out):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        """Upsample ``x``, fuse it with ``skip`` and apply the residual block.

        Args:
            x: Lower-resolution feature map. Bilinear upsampling in PyTorch
                requires a 4-D tensor, i.e. (N, C, H, W).
            skip: Skip-connection feature map. It must share the batch size of
                ``x``; its spatial size (last two dims) is the target size.

        Returns:
            Whatever ``self.block`` returns for the channel-wise concatenation
            of the upsampled ``x`` and ``skip``.
        """
        up = self.upsample(x)
        # Doubling the resolution lands one pixel short when the skip map had
        # an odd height/width, and torch.cat would then raise. In that case
        # resample the original ``x`` straight to the skip's spatial size with
        # the same interpolation settings. When sizes already match this branch
        # is not taken and results are identical to plain 2x upsampling.
        if up.shape[-2:] != skip.shape[-2:]:
            up = nn.functional.interpolate(
                x,
                size=tuple(skip.shape[-2:]),
                mode=self.upsample.mode,
                align_corners=self.upsample.align_corners,
            )
        fused = torch.cat([up, skip], dim=1)
        return self.block(fused)