File size: 2,223 Bytes
719d0db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import math
import torch
import torch.nn as nn

class MLPNodeEncoder(nn.Module):
    """Embed node features with separate depot/location projections followed by an MLP.

    Node 0 of every instance is projected with ``init_linear_depot`` using only
    its first ``coord_dim`` features. Nodes 1.. are projected with
    ``init_linear_nodes`` using all ``node_dim`` features. The concatenated
    embeddings then pass through ``num_mlp_layers`` ``emb_dim -> emb_dim``
    linear layers, with a ReLU between consecutive layers and no activation
    after the last one.

    Parameters
    ----------
    coord_dim: int
        Number of leading depot features fed to the depot projection.
    node_dim: int
        Number of features per (non-depot) node.
    emb_dim: int
        Embedding size produced by every layer.
    num_mlp_layers: int
        Number of linear layers applied after the initial projection. With 0
        layers the initial projection is returned unchanged.
    dropout: float
        Probability for ``self.dropout``. NOTE: this module is constructed but
        is not currently applied in ``forward``.
    """

    def __init__(self, coord_dim, node_dim, emb_dim, num_mlp_layers, dropout):
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.num_mlp_layers = num_mlp_layers

        # Initial embedding: the depot gets its own projection over coord_dim features.
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MLP encoder (layers created in order, so init RNG / state_dict order is stable).
        self.mlp = nn.ModuleList(
            [nn.Linear(emb_dim, emb_dim, bias=True) for _ in range(num_mlp_layers)]
        )

        # Kept for configuration/API compatibility; not applied in forward().
        self.dropout = nn.Dropout(dropout)

    def reset_parameters(self):
        """Re-initialize every parameter uniformly in [-1/sqrt(d), 1/sqrt(d)].

        ``d`` is the size of the parameter's last dimension (in_features for
        weight matrices, the vector length for biases). Not called by __init__.
        """
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        """Compute node embeddings.

        Parameters
        ----------
        inputs: dict
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim]
                Node 0 is treated as the depot and nodes 1.. as locations. Only
                the first ``coord_dim`` features of the depot are used
                (presumably its coordinates -- confirm against the data pipeline).

        Returns
        -------
        node_emb: torch.Tensor [batch_size x num_nodes x emb_dim]
            Node embeddings, depot first.
        """
        node_feat = inputs["node_feats"]

        # Initial linear projection: separate layers for locations and the depot.
        loc_emb = self.init_linear_nodes(node_feat[:, 1:, :])  # [batch_size x num_loc x emb_dim]
        # Slice by coord_dim (not a hard-coded 2) so the input always matches
        # init_linear_depot's in_features.
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim])  # [batch_size x 1 x emb_dim]
        node_emb = torch.cat((depot_emb, loc_emb), 1)  # [batch_size x num_nodes x emb_dim]

        # MLP encoding: ReLU between layers, none after the final layer.
        last = len(self.mlp) - 1
        for i, layer in enumerate(self.mlp):
            node_emb = layer(node_emb)
            if i != last:
                node_emb = torch.relu(node_emb)
        return node_emb