File size: 1,423 Bytes
d015578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch.nn as nn

from spiga.models.gnn.layers import MLP
from spiga.models.gnn.gat import GAT


class StepRegressor(nn.Module):
    """Stack of ``nstack`` GAT layers followed by an offset decoder.

    The first GAT layer maps ``input_dim`` features to ``feature_dim``; each of
    the remaining ``nstack - 1`` layers maps ``feature_dim`` to ``feature_dim``.
    The final embedding is fed to ``OffsetDecoder``.
    """

    def __init__(self, input_dim: int, feature_dim: int, nstack=4, decoding=(256, 128, 64, 32)):
        """Build the GAT stack and the decoder head.

        Args:
            input_dim: Feature size given to the first GAT layer.
            feature_dim: Feature size produced by every GAT layer and consumed
                by the decoder.
            nstack: Number of GAT layers; must be positive.
            decoding: Hidden layer sizes of the decoder MLP. Any iterable of
                ints is accepted; the default is a tuple so no mutable default
                object is shared between instances.
        """
        super(StepRegressor, self).__init__()
        assert nstack > 0
        self.nstack = nstack
        # NOTE(review): the third GAT argument (4) is hard-coded; its meaning
        # (presumably the number of attention heads) is defined in
        # spiga.models.gnn.gat — confirm there before parameterizing it.
        self.gat = nn.ModuleList([GAT(input_dim, feature_dim, 4)])
        for _ in range(nstack - 1):
            self.gat.append(GAT(feature_dim, feature_dim, 4))
        # Pass a fresh list: the decoder builds its layer sizes from it.
        self.decoder = OffsetDecoder(feature_dim, list(decoding))

    def forward(self, embedded, prob_list=None):
        """Run the GAT stack and decode offsets.

        The last two axes of ``embedded`` are swapped before the GAT stack, and
        the decoder output is swapped back before returning (assumes callers
        provide the feature axis second-to-last — TODO confirm against callers).

        Args:
            embedded: Input tensor.
            prob_list: Optional list that receives the second value returned
                by each GAT layer, in layer order. A new list is created when
                omitted. The old mutable default was shared across calls, so it
                kept growing and leaked tensors from earlier batches.

        Returns:
            Tuple ``(offset, prob_list)``.
        """
        if prob_list is None:
            prob_list = []
        embedded = embedded.transpose(-1, -2)
        for gat_layer in self.gat:
            # Each GAT layer returns (updated embedding, per-layer prob).
            embedded, prob = gat_layer(embedded)
            prob_list.append(prob)
        offset = self.decoder(embedded)
        return offset.transpose(-1, -2), prob_list


class OffsetDecoder(nn.Module):
    """MLP head whose layer sizes are ``feature_dim -> *layers -> 2``."""

    def __init__(self, feature_dim, layers):
        """Build the decoder MLP.

        Args:
            feature_dim: Size of the first entry in the MLP size list.
            layers: Hidden layer sizes; any iterable of ints (list, tuple, ...).
        """
        super().__init__()
        # Unpacking accepts any iterable. The previous list concatenation
        # raised TypeError when ``layers`` was a tuple.
        self.decoder = MLP([feature_dim, *layers, 2])

    def forward(self, embedded):
        """Apply the decoder MLP to ``embedded`` and return its output."""
        return self.decoder(embedded)


class RelativePositionEncoder(nn.Module):
    """MLP encoder whose layer sizes are ``input_dim -> *layers -> feature_dim``."""

    def __init__(self, input_dim, feature_dim, layers):
        """Build the encoder MLP.

        Args:
            input_dim: Size of the first entry in the MLP size list.
            feature_dim: Size of the last entry in the MLP size list.
            layers: Hidden layer sizes; any iterable of ints (list, tuple, ...).
        """
        super().__init__()
        # Unpacking accepts any iterable. The previous list concatenation
        # raised TypeError when ``layers`` was a tuple.
        self.encoder = MLP([input_dim, *layers, feature_dim])

    def forward(self, feature):
        """Encode ``feature``, swapping its last two axes around the MLP.

        The input's last two axes are transposed before the MLP and the result
        is transposed back, so the output keeps the input's axis order.
        """
        swapped = feature.transpose(-1, -2)
        return self.encoder(swapped).transpose(-1, -2)