import torch
import torch.nn as nn
import torch.nn.functional as F  # was `torch.functional`, which is not the nn functional API
from bn import batch_norm
from residual import residual
from decoder import decoder


class reunet(nn.Module):
    """Residual U-Net assembled from the project's ``batch_norm``,
    ``residual`` and ``decoder`` building blocks.

    Layout:
      * encoder 1: 3x3 conv -> ``batch_norm`` -> 3x3 conv, summed with a
        1x1-conv shortcut taken directly from the input;
      * encoders 2 and 3: ``residual`` blocks with ``stride=2``;
      * bridge: a ``residual`` block with ``stride=2``;
      * decoder: three ``decoder`` stages, each given the previous stage's
        features and the matching encoder skip tensor;
      * output: 1x1 conv followed by a sigmoid.

    Args:
        in_channels: number of channels of the input tensor. Defaults to 3,
            the value previously hard-coded.
        out_channels: number of channels produced by the output 1x1 conv.
            Defaults to 1, the value previously hard-coded.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # NOTE: attribute names are deliberately unchanged so that existing
        # state_dict checkpoints still load.

        # encoder 1
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, stride=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        # 1x1 projection so the shortcut matches conv2's 64 channels
        self.conv3 = nn.Conv2d(in_channels, 64, kernel_size=1, padding=0, stride=1)

        # encoder 2
        self.enc2 = residual(64, 128, stride=2)
        # encoder 3
        self.enc3 = residual(128, 256, stride=2)

        # bridge
        self.bridge = residual(256, 512, stride=2)

        # decoder
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)

        # output
        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Run the network on a batch of inputs.

        Args:
            input: tensor fed to ``conv1`` and ``conv3``; presumably shaped
                (N, in_channels, H, W). NOTE(review): with three stride-2
                stages, H and W likely need to be divisible by 8 for the
                decoder skip connections to line up. Confirm against the
                ``residual`` / ``decoder`` implementations.

        Returns:
            Sigmoid of the 1x1 output convolution, with ``out_channels``
            channels.
        """
        # encoder 1: main path plus 1x1 shortcut. The local is named
        # ``shortcut`` so it does not shadow the imported ``residual`` class.
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        shortcut = self.conv3(input)
        skip1 = x + shortcut

        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)

        # bridge
        b = self.bridge(skip3)

        # decoder: each stage receives the previous features and the
        # encoder output of matching depth
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)

        # output
        output = self.output(d3)
        output = self.sigmoid(output)
        return output