File size: 1,566 Bytes
35d85a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
import torch.functional as F

from bn import batch_norm
from residual import residual 
from decoder import decoder

class reunet(nn.Module):
    """Residual U-Net style segmentation network.

    Wiring, as built in ``__init__`` / ``forward``:
      * encoder 1: 3x3 conv -> ``batch_norm`` -> 3x3 conv (stride 1, so the
        spatial size is kept), added to a 1x1-conv projection of the input;
      * encoders 2 and 3 and the bridge: ``residual`` blocks built with
        ``stride=2`` (64 -> 128 -> 256 -> 512 channels);
      * three ``decoder`` stages (512 -> 256 -> 128 -> 64), each given the
        matching encoder output as a skip connection;
      * a 1x1 conv head followed by a sigmoid.

    Args:
        in_channels: Channels of the input image. Defaults to 3, the value
            that used to be hard-coded.
        out_channels: Channels produced by the output head. Defaults to 1,
            the value that used to be hard-coded.

    Submodule attribute names (``conv1``, ``bn1``, ..., ``output``) are kept
    unchanged so previously saved ``state_dict`` checkpoints still load.
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1):
        super().__init__()
        # encoder 1 (64 channels)
        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=3, padding=1, stride=1
        )
        # NOTE(review): `batch_norm` is a project helper; if it does not apply
        # a nonlinearity, conv1 -> bn1 -> conv2 is purely affine — confirm.
        self.bn1 = batch_norm(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1)
        # 1x1 projection so the shortcut has the same 64 channels as the main path
        self.conv3 = nn.Conv2d(
            in_channels, 64, kernel_size=1, padding=0, stride=1
        )

        # encoders 2 and 3
        self.enc2 = residual(64, 128, stride=2)
        self.enc3 = residual(128, 256, stride=2)

        # bridge
        self.bridge = residual(256, 512, stride=2)

        # decoder
        self.d1 = decoder(512, 256)
        self.d2 = decoder(256, 128)
        self.d3 = decoder(128, 64)

        # output head
        self.output = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Run the network on a batch of images.

        Args:
            input: Image tensor, presumably ``(batch, in_channels, H, W)``.
                H and W likely need to be divisible by 8 (three stride-2
                stages) for the decoder skips to line up — TODO confirm
                against ``residual`` / ``decoder``.

        Returns:
            The sigmoid of the 1x1-conv head output, with ``out_channels``
            channels.
        """
        # encoder 1: main path plus 1x1 shortcut
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.conv2(x)
        # Named `shortcut` (not `residual`) so it does not shadow the imported
        # `residual` block class.
        shortcut = self.conv3(input)
        skip1 = x + shortcut

        # encoders 2 and 3
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)

        # bridge
        b = self.bridge(skip3)

        # decoder: each stage receives the deeper features and the matching skip
        d1 = self.d1(b, skip3)
        d2 = self.d2(d1, skip2)
        d3 = self.d3(d2, skip1)

        # output head
        output = self.output(d3)
        return self.sigmoid(output)