Spaces:
Running
Running
import math | |
import torch | |
import torch.nn as nn | |
class SelfMHANodeEncoder(nn.Module):
    """Node encoder: per-node linear embedding followed by a self-attention stack.

    The node at index 0 along the node axis is embedded with its own linear
    layer (``init_linear_depot``) that only sees its first ``coord_dim``
    features; all remaining nodes are embedded with ``init_linear_nodes``
    using the full ``node_dim`` features. The concatenated embeddings are
    then passed through ``num_mha_layers`` ``nn.TransformerEncoderLayer``
    blocks. No positional encoding is added.

    Parameters
    ----------
    coord_dim: int
        number of leading features of node 0 fed to the depot embedding
        (presumably its coordinates -- confirm against the data loader)
    node_dim: int
        number of features per node
    emb_dim: int
        embedding size; also used as the feed-forward width of each
        encoder layer
    num_heads: int
        number of attention heads (``emb_dim`` must be divisible by it,
        as required by ``nn.TransformerEncoderLayer``)
    num_mha_layers: int
        number of stacked encoder layers
    dropout: float
        dropout probability used inside each encoder layer
    """

    def __init__(
        self,
        coord_dim: int,
        node_dim: int,
        emb_dim: int,
        num_heads: int,
        num_mha_layers: int,
        dropout: float,
    ) -> None:
        super().__init__()
        self.coord_dim = coord_dim
        self.node_dim = node_dim
        self.emb_dim = emb_dim
        self.num_mha_layers = num_mha_layers

        # initial embedding
        self.init_linear_nodes = nn.Linear(node_dim, emb_dim)
        self.init_linear_depot = nn.Linear(coord_dim, emb_dim)

        # MHA Encoder (w/o positional encoding)
        mha_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=emb_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.mha = nn.TransformerEncoder(mha_layer, num_layers=num_mha_layers)

        # Initializing weights
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw every parameter from U(-1/sqrt(d), 1/sqrt(d)), d = its last dim.

        This loops over *all* parameters of the module, so biases and the
        normalization parameters inside the encoder layers are overwritten
        too.
        NOTE(review): presumably intentional (same scheme for everything) --
        confirm before changing, since it affects training behaviour.
        """
        # no_grad() lets us update the parameters in place without touching
        # autograd, which is the supported replacement for mutating `.data`.
        with torch.no_grad():
            for param in self.parameters():
                stdv = 1. / math.sqrt(param.size(-1))
                param.uniform_(-stdv, stdv)

    def forward(self, inputs: dict) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs: dict
            node_feats: torch.FloatTensor [batch_size x num_nodes x node_dim]
                index 0 along the node axis is handled as the depot

        Returns
        -------
        node_emb: torch.tensor [batch_size x num_nodes x emb_dim]
            node embeddings
        """
        #----------------
        # input features
        #----------------
        node_feat = inputs["node_feats"]

        #------------------------------------------------------------------------
        # initial linear projection for adjusting dimensions of locs & the depot
        #------------------------------------------------------------------------
        loc_emb = self.init_linear_nodes(node_feat[:, 1:, :])  # [batch_size x num_loc x emb_dim]
        # Slice with coord_dim (not a literal 2) so the input width always
        # matches the in_features of init_linear_depot.
        depot_emb = self.init_linear_depot(node_feat[:, 0:1, :self.coord_dim])  # [batch_size x 1 x emb_dim]
        node_emb = torch.cat((depot_emb, loc_emb), 1)  # [batch_size x num_nodes x emb_dim]

        #--------------
        # MHA encoding
        #--------------
        node_emb = self.mha(node_emb)  # [batch_size x num_nodes x emb_dim]
        return node_emb