import math
import torch
import torch.nn as nn

class SelfMHANodeEncoder(nn.Module):
    def __init__(self, coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout):
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim  = node_dim
        self.emb_dim   = emb_dim
        self.num_mha_layers = num_mha_layers

        # initial embedding
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MHA Encoder (w/o positional encoding)
        mha_layer = nn.TransformerEncoderLayer(d_model=emb_dim, 
                                               nhead=num_heads,
                                               dim_feedforward=emb_dim,
                                               dropout=dropout,
                                               batch_first=True)
        self.mha = nn.TransformerEncoder(mha_layer, num_layers=num_mha_layers)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        # uniformly initialize every parameter (including the Transformer
        # encoder weights), scaled by the size of its last dimension
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        """
        Paramters
        ---------
        inputs: dict
            node_feat: torch.FloatTensor [batch_size x num_nodes x node_dim]

        Returns
        -------
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings 
        """
        #----------------
        # input features
        #----------------
        node_feat = inputs["node_feats"]

        #------------------------------------------------------------------------
        # initial linear projection for adjusting dimensions of locs & the depot
        #------------------------------------------------------------------------
        loc_emb   = self.init_linear_nodes(node_feat[:, 1:, :])                # [batch_size x num_loc x emb_dim]
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim]) # [batch_size x 1 x emb_dim]
        node_emb  = torch.cat((depot_emb, loc_emb), 1)                         # [batch_size x num_nodes x emb_dim]

        #--------------
        # MHA encoding
        #--------------
        node_emb = self.mha(node_emb) # [batch_size x num_nodes x emb_dim]

        return node_emb
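
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the dimensions and
# hyperparameters below are illustrative assumptions, e.g. 2-D coordinates
# plus one extra node feature such as demand.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, num_nodes = 4, 21  # 1 depot + 20 customer locations (assumed)
    coord_dim, node_dim = 2, 3     # (x, y) coordinates; one extra node feature
    encoder = SelfMHANodeEncoder(coord_dim=coord_dim,
                                 node_dim=node_dim,
                                 emb_dim=128,
                                 num_heads=8,
                                 num_mha_layers=3,
                                 dropout=0.1)
    inputs = {"node_feats": torch.rand(batch_size, num_nodes, node_dim)}
    node_emb = encoder(inputs)     # [batch_size x num_nodes x emb_dim]
    print(node_emb.shape)          # torch.Size([4, 21, 128])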