# File: Building_area/model/model.py
# Last commit by Pavan2k4: "app" (35d85a5)
import torch
import torch.nn as nn
import torch.functional as F
from bn import batch_norm
from residual import residual
from decoder import decoder
class reunet(nn.Module):
    """Residual U-Net (ResUNet-style) that outputs a per-pixel sigmoid map.

    Wiring, as built in ``__init__``:
      * encoder 1: 3x3 conv -> ``batch_norm`` -> 3x3 conv, summed with a
        1x1-conv projection of the raw input (64 channels);
      * encoders 2-3 and the bridge: ``residual`` blocks constructed with
        stride 2, widening 64 -> 128 -> 256 -> 512 channels;
      * three ``decoder`` stages, each given the previous stage output and the
        matching encoder skip (512->256, 256->128, 128->64 channels);
      * a 1x1 conv to ``out_channels`` followed by ``nn.Sigmoid``.

    Args:
        in_channels: channels of the input image. Defaults to 3, the value
            that was previously hard-coded.
        out_channels: number of output maps. Defaults to 1, the value that
            was previously hard-coded.

    Submodule attribute names (conv1, bn1, ..., output, sigmoid) are kept
    stable on purpose: they are the keys of ``state_dict`` checkpoints.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # Encoder 1: conv stem plus a 1x1 projection shortcut (conv3) so the
        # raw input can be added to the 64-channel features.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, stride=1)
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.Conv2d(in_channels, 64, kernel_size=1, padding=0, stride=1)
        # Encoders 2 and 3. stride=2 presumably halves H and W in each block;
        # ``residual`` is defined elsewhere. TODO confirm.
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)
        # Bridge: deepest stage, between the encoder and the decoder.
        self.bridge = residual(256, 512, stride=2)
        # Decoder: each stage takes (previous output, encoder skip).
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)
        # Output head: 1x1 conv down to out_channels, then a sigmoid.
        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Run the network on a batch of images.

        Args:
            input: image tensor; presumably (N, in_channels, H, W). The three
                stride-2 stages suggest H and W should be divisible by 8 so
                the decoder skips line up. TODO confirm against ``decoder``.
                (The name ``input`` shadows the builtin but is kept so
                keyword callers keep working.)

        Returns:
            Sigmoid-activated tensor (values in [0, 1]) with ``out_channels``
            channels.
        """
        # Encoder 1: main conv path plus a 1x1 projection of the raw input.
        # The shortcut is deliberately not named ``residual``, so it does not
        # shadow the imported ``residual`` block class.
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        shortcut = self.conv3(input)
        skip1 = x + shortcut

        # Encoders 2 and 3; their outputs double as decoder skip connections.
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)

        # Bridge.
        b = self.bridge(skip3)

        # Decoder: walk back up, pairing each stage with its encoder skip.
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)

        # Output head.
        output = self.output(d3)
        return self.sigmoid(output)