Spaces:
Sleeping
Sleeping
File size: 1,566 Bytes
35d85a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import torch
import torch.nn as nn
import torch.functional as F
from bn import batch_norm
from residual import residual
from decoder import decoder
class reunet(nn.Module):
    """Residual U-Net style network ending in a sigmoid.

    Wiring, as constructed below:

    * encoder 1: 3x3 conv -> ``batch_norm`` -> 3x3 conv, added to a 1x1
      convolutional shortcut of the raw input (both 64 channels);
    * encoders 2 and 3 and the bridge: ``residual`` blocks built with
      ``stride=2`` (64->128, 128->256, 256->512 channels);
    * three ``decoder`` stages (512->256, 256->128, 128->64), each given the
      matching encoder output as a skip connection;
    * a 1x1 convolution to ``out_channels`` maps followed by a sigmoid.

    NOTE(review): the exact behaviour of ``batch_norm``, ``residual`` and
    ``decoder`` (e.g. whether ``batch_norm`` also applies an activation, or
    how ``decoder`` upsamples/merges the skip) lives in other modules --
    confirm there.

    Args:
        in_channels: Channels of the input image, used by both ``conv1`` and
            the ``conv3`` shortcut so the two paths stay consistent.
            Defaults to 3.
        out_channels: Number of maps produced by the final 1x1 convolution.
            Defaults to 1.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # encoder 1
        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=3, padding=1, stride=1
        )
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        # 1x1 projection so the raw input can be added to the 64-channel
        # feature map produced by conv2.
        self.conv3 = nn.Conv2d(
            in_channels, 64, kernel_size=1, padding=0, stride=1
        )
        # encoder 2
        self.enc2 = residual(64, 128, stride=2)
        # encoder 3
        self.enc3 = residual(128, 256, stride=2)
        # bridge
        self.bridge = residual(256, 512, stride=2)
        # decoder
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)
        # output
        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Run the network and return sigmoid-activated output maps.

        Args:
            input: Image batch fed to ``conv1`` and ``conv3``; presumably a
                ``(batch, in_channels, H, W)`` tensor -- TODO confirm. With
                three ``stride=2`` stages, H and W likely need to be
                divisible by 8 for the decoder skip merges to line up --
                verify against ``residual``/``decoder``. (The parameter name
                shadows the builtin ``input`` but is kept for callers that
                pass it by keyword.)

        Returns:
            The final 1x1 convolution output passed through ``nn.Sigmoid``,
            so every value lies in [0, 1].
        """
        # encoder 1
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        # Named ``shortcut`` rather than ``residual`` so it does not shadow
        # the imported ``residual`` block class.
        shortcut = self.conv3(input)
        skip1 = x + shortcut
        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        # bridge
        b = self.bridge(skip3)
        # decoder: each stage consumes the previous output plus its skip
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)
        # output
        output = self.output(d3)
        output = self.sigmoid(output)
        return output
|