File size: 1,575 Bytes
d015578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch.nn as nn

from spiga.models.cnn.layers import Conv, Deconv, Residual


class Hourglass(nn.Module):
    """Recursive hourglass block: a skip branch summed with a nested lower branch.

    Args:
        n: Nesting depth. While ``n > 1`` the middle of the lower branch is
            another ``Hourglass`` built with ``n - 1``; otherwise it is a
            single ``Residual``.
        f: Size argument handed to the ``Residual``/``Conv``/``Deconv`` layers
            of this level (presumably the feature-channel count -- confirm
            against ``spiga.models.cnn.layers``).
        bn: Only forwarded to the nested ``Hourglass``; the ``Conv`` and
            ``Deconv`` created here always receive ``bn=True``.
        increase: Added to ``f`` to obtain the size used inside the lower
            branch of this level. It is not forwarded to nested levels.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__()
        self.n = n
        inner = f + increase

        # Skip branch.
        self.up1 = Residual(f, f)

        # Lower branch. Submodules are created in the order
        # pool1 -> low1 -> low2 -> low3 -> up2 so registration order is stable.
        self.pool1 = Conv(f, f, 2, 2, bn=True, relu=True)
        self.low1 = Residual(f, inner)
        # Recursive hourglass until the innermost level is reached.
        self.low2 = (
            Hourglass(n - 1, inner, bn=bn) if n > 1 else Residual(inner, inner)
        )
        self.low3 = Residual(inner, f)
        self.up2 = Deconv(f, f, 2, 2, bn=True, relu=True)

    def forward(self, x):
        """Return ``up1(x)`` plus the output of the lower branch applied to ``x``."""
        skip = self.up1(x)
        lower = self.low1(self.pool1(x))
        lower = self.low3(self.low2(lower))
        return skip + self.up2(lower)


class HourglassCore(Hourglass):
    """Hourglass variant whose forward pass also collects lower-branch outputs.

    The layer layout is inherited from ``Hourglass``; the only structural
    difference is that nested levels are ``HourglassCore`` instances so that
    they can append to, and return, the shared ``core`` list.

    Args:
        n: Nesting depth (see ``Hourglass``).
        f: Size argument forwarded to ``Hourglass``.
        bn: Forwarded to ``Hourglass`` and to the nested ``HourglassCore``.
        increase: Forwarded to ``Hourglass``; also added to ``f`` to size the
            nested ``HourglassCore``.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super(HourglassCore, self).__init__(n, f, bn=bn, increase=increase)
        nf = f + increase
        if self.n > 1:
            # The parent constructor built a plain Hourglass here; replace it
            # so that nested levels also return their collected features.
            self.low2 = HourglassCore(n - 1, nf, bn=bn)

    def forward(self, x, core=None):
        """Run the block and collect intermediate lower-branch outputs.

        Args:
            x: Input passed to both the skip branch and the lower branch.
            core: Optional list to append the collected outputs to. It is
                extended in place. When omitted, a fresh list is created for
                this call, so results never leak between forward passes.

        Returns:
            A tuple ``(out, core)`` where ``out`` is ``up1 + up2`` and
            ``core`` is the list of collected outputs: first the innermost
            level's ``low2`` output, then the ``low3`` output of every
            enclosing level, from the inside out.
        """
        # A mutable default would be shared by every call that omits `core`
        # and would keep accumulating tensors across forward passes.
        if core is None:
            core = []
        up1 = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        if self.n > 1:
            # The nested level appends to the same list and hands it back.
            low2, core = self.low2(low1, core=core)
        else:
            low2 = self.low2(low1)
            core.append(low2)
        low3 = self.low3(low2)
        if self.n > 1:
            core.append(low3)
        up2 = self.up2(low3)
        return up1 + up2, core