Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.functional as F | |
from bn import batch_norm | |
from residual import residual | |
from decoder import decoder | |
class reunet(nn.Module):
    """Residual U-Net (ResUNet) producing sigmoid-activated per-pixel maps.

    Layout:
    - a residual stem (encoder 1);
    - two strided residual encoder stages;
    - a strided residual bridge;
    - three decoder stages, each fed by the matching encoder skip tensor;
    - a 1x1 convolution followed by a sigmoid.

    Args:
        in_channels: Channels of the input image. Defaults to 3, the value
            that was previously hard-coded.
        out_channels: Channels produced by the final 1x1 convolution.
            Defaults to 1, the value that was previously hard-coded.

    NOTE(review): ``batch_norm``, ``residual`` and ``decoder`` come from
    project-local modules not visible here. The two encoder stages and the
    bridge are built with ``stride=2``. Presumably the input height and width
    must therefore be divisible by 8 for decoder outputs to line up with the
    skip tensors — confirm against ``decoder``.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder 1 (stem). The main path is conv -> batch_norm -> conv.
        # A 1x1 projection of the raw input serves as the shortcut, so both
        # paths produce 64 channels and can be summed in forward().
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, stride=1)
        # NOTE(review): the reference ResUNet stem applies a ReLU between
        # this normalization and conv2. If batch_norm does not already
        # include an activation, one is missing here — confirm.
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels, 64, kernel_size=1, padding=0, stride=1)
        # Encoders 2 and 3: residual stages (64 -> 128, 128 -> 256) with stride 2.
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)
        # Bridge: deepest residual stage (256 -> 512) with stride 2.
        self.bridge = residual(256, 512, stride=2)
        # Decoders: each consumes the previous stage and one encoder skip.
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)
        # Output head: 1x1 conv down to out_channels, then sigmoid.
        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the network and return the sigmoid of the output head.

        Args:
            input: Image batch, presumably shaped (N, in_channels, H, W)
                since it feeds ``nn.Conv2d`` layers — TODO confirm.

        Returns:
            ``self.output(...)`` passed through ``self.sigmoid``, so every
            value is squashed into the 0-1 range.
        """
        # Encoder 1: main path plus projected shortcut. The shortcut is named
        # ``shortcut`` (not ``residual``) so it does not shadow the imported
        # ``residual`` block class.
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        shortcut = self.conv3(input)
        skip1 = x + shortcut
        # Encoders 2 and 3.
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        # Bridge.
        b = self.bridge(skip3)
        # Decoder path, consuming the skip tensors deepest-first.
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)
        # Output head.
        output = self.output(d3)
        return self.sigmoid(output)