# File: Building_area/model/decoder.py
# Repository page header: author Pavan2k4, commit "app" (35d85a5)
import torch
import torch.nn as nn
import torch.functional as F
from bn import batch_norm
from residual import residual
class decoder(nn.Module):
    """Decoder stage: upsample, concatenate a skip tensor, apply a residual block.

    The input is upsampled by a factor of 2 with bilinear interpolation,
    joined with ``skip`` along dimension 1, and passed through ``residual``.

    Args:
        inp: First summand of the size handed to ``residual``; presumably the
            channel count of ``x`` (``residual`` is defined elsewhere --
            confirm against its signature).
        out: Second summand of that size and the second argument of
            ``residual``; presumably the channel count of ``skip`` and of the
            result.
    """

    def __init__(self, inp, out):
        super().__init__()
        self.upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True
        )
        # The block receives the concatenation of both tensors, hence inp + out.
        self.block = residual(inp + out, out)

    def forward(self, x, skip):
        """Upsample ``x``, merge it with ``skip`` and run the residual block.

        Args:
            x: Tensor to upsample. Bilinear mode requires a 4-D input, i.e.
                (N, C, H, W).
            skip: Tensor concatenated with the upsampled ``x`` along dim 1.
                All of its dimensions except dim 1 must match after resizing.

        Returns:
            Whatever ``self.block`` returns for the concatenated tensor.
        """
        upsampled = self.upsample(x)

        # A fixed 2x factor cannot reproduce odd spatial sizes (e.g. 7 -> 3
        # -> 6), which would make torch.cat fail. In that case resample the
        # original input straight to the skip tensor's spatial size instead of
        # interpolating the already-upsampled tensor a second time.
        if upsampled.shape[2:] != skip.shape[2:]:
            upsampled = nn.functional.interpolate(
                x,
                size=skip.shape[2:],
                mode=self.upsample.mode,
                align_corners=self.upsample.align_corners,
            )

        merged = torch.cat([upsampled, skip], dim=1)
        return self.block(merged)