daisuke.kikuta
first commit
719d0db
import math
import torch
import torch.nn as nn
class MLPNodeEncoder(nn.Module):
def __init__(self, coord_dim, node_dim, emb_dim, num_mlp_layers, dropout):
super().__init__()
self.coord_dim = coord_dim
self.node_dim = node_dim
self.emb_dim = emb_dim
self.num_mlp_layers = num_mlp_layers
# initial embedding
self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
self.init_linear_depot = nn.Linear(coord_dim, emb_dim)
# MLP Encoder
self.mlp = nn.ModuleList()
for _ in range(num_mlp_layers):
self.mlp.append(nn.Linear(emb_dim, emb_dim, bias=True))
# Dropout
self.dropout = nn.Dropout(dropout)
def reset_parameters(self):
for param in self.parameters():
stdv = 1. / math.sqrt(param.size(-1))
param.data.uniform_(-stdv, stdv)
def forward(self, inputs):
"""
Paramters
---------
inputs: dict
node_feat: torch.FloatTensor [batch_size x num_nodes x node_dim]
Returns
-------
node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
node embeddings
"""
#----------------
# input features
#----------------
node_feat = inputs["node_feats"]
#------------------------------------------------------------------------
# initial linear projection for adjusting dimensions of locs & the depot
#------------------------------------------------------------------------
# node_feat = self.dropout(node_feat)
loc_emb = self.init_linear_nodes(node_feat[:, 1:, :]) # [batch_size x num_loc x emb_dim]
depot_emb = self.init_linear_depot(node_feat[:, 0:1, :2]) # [batch_size x 1 x emb_dim]
node_emb = torch.cat((depot_emb, loc_emb), 1) # [batch_size x num_nodes x emb_dim]
#--------------
# MLP encoding
#--------------
for i in range(self.num_mlp_layers):
# node_emb = self.dropout(node_emb)
node_emb = self.mlp[i](node_emb)
if i != self.num_mlp_layers - 1:
node_emb = torch.relu(node_emb)
return node_emb