import math

import torch
import torch.nn as nn


class SelfMHANodeEncoder(nn.Module):
    def __init__(self, coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout):
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.num_mha_layers = num_mha_layers

        # initial embedding
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MHA Encoder (w/o positional encoding)
        mha_layer = nn.TransformerEncoderLayer(d_model=emb_dim,
                                               nhead=num_heads,
                                               dim_feedforward=emb_dim,
                                               dropout=dropout,
                                               batch_first=True)
        self.mha = nn.TransformerEncoder(mha_layer, num_layers=num_mha_layers)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self):
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)
    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim]
                node features; the first node is the depot

        Returns
        -------
        node_emb: torch.FloatTensor [batch_size x num_nodes x emb_dim]
            node embeddings
        """
        #----------------
        # input features
        #----------------
        node_feat = inputs["node_feats"]

        #------------------------------------------------------------------------
        # initial linear projection for adjusting dimensions of locs & the depot
        #------------------------------------------------------------------------
        # node_feat = self.dropout(node_feat)
        loc_emb = self.init_linear_nodes(node_feat[:, 1:, :])  # [batch_size x num_loc x emb_dim]
        # the depot uses only its coord_dim coordinate features
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim])  # [batch_size x 1 x emb_dim]
        node_emb = torch.cat((depot_emb, loc_emb), 1)  # [batch_size x num_nodes x emb_dim]
        #--------------
        # MHA encoding
        #--------------
        node_emb = self.mha(node_emb)  # [batch_size x num_nodes x emb_dim]
        return node_emb
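

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# Assumptions: 2-D coordinates plus one extra node feature (e.g. demand), so
# coord_dim=2 and node_dim=3; the hyperparameters and random inputs below are
# placeholders chosen just to show the expected shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, num_nodes = 4, 21            # depot + 20 customer locations
    coord_dim, node_dim, emb_dim = 2, 3, 128

    encoder = SelfMHANodeEncoder(coord_dim=coord_dim,
                                 node_dim=node_dim,
                                 emb_dim=emb_dim,
                                 num_heads=8,
                                 num_mha_layers=3,
                                 dropout=0.1)

    # node_feats: [batch_size x num_nodes x node_dim]; the first node is the
    # depot, whose first coord_dim entries are its coordinates.
    inputs = {"node_feats": torch.rand(batch_size, num_nodes, node_dim)}
    node_emb = encoder(inputs)               # [batch_size x num_nodes x emb_dim]
    print(node_emb.shape)                    # torch.Size([4, 21, 128])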