# Spaces: Running (web-page header residue; not part of the module code)
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
# Node encoder | |
from models.classifiers.nn_classifiers.encoders.mlp_node_encoder import MLPNodeEncoder | |
from models.classifiers.nn_classifiers.encoders.mha_node_encoder import SelfMHANodeEncoder | |
# Edge encoder | |
from models.classifiers.nn_classifiers.encoders.attn_edge_encoder import AttentionEdgeEncoder | |
from models.classifiers.nn_classifiers.encoders.concat_edge_encoder import ConcatEdgeEncoder | |
# Decoder | |
from models.classifiers.nn_classifiers.decoders.lstm_decoder import LSTMDecoder | |
from models.classifiers.nn_classifiers.decoders.mlp_decoder import MLPDecoder | |
from models.classifiers.nn_classifiers.decoders.mha_decoder import SelfMHADecoder | |
# data loader | |
from utils.data_utils.tsptw_dataset import load_tsptw_sequentially | |
from utils.data_utils.pctsp_dataset import load_pctsp_sequentially | |
from utils.data_utils.pctsptw_dataset import load_pctsptw_sequentially | |
from utils.data_utils.cvrp_dataset import load_cvrp_sequentially | |
# Accepted values for NNClassifier's `node_enc_type` argument.
NODE_ENC_LIST = ["mlp", "mha"]
# Edge-encoder (readout) type names; NNClassifier branches on "concat" / "attn".
EDGE_ENC_LIST = ["concat", "attn"]
# Accepted values for NNClassifier's `dec_type` argument.
DEC_LIST = ["mlp", "lstm", "mha"]
class NNClassifier(nn.Module):
    """Classifier built from a node encoder, an edge-encoder readout and a decoder.

    The forward pass runs ``node_enc`` -> ``readout`` -> ``decoder``. When the
    decoder type is "lstm" or "mha" the model is treated as sequential and the
    inputs carry an extra sequence dimension.
    """

    def __init__(self,
                 problem: str,
                 node_enc_type: str,
                 edge_enc_type: str,
                 dec_type: str,
                 emb_dim: int,
                 num_enc_mlp_layers: int,
                 num_dec_mlp_layers: int,
                 num_classes: int,
                 dropout: float,
                 pos_encoder: str = "sincos"):
        """Build the encoder, readout and decoder sub-modules.

        Parameters
        ----------
        problem: str
            One of "tsptw", "pctsp", "pctsptw", "cvrp".
        node_enc_type: str
            Node encoder type; must be in NODE_ENC_LIST.
        edge_enc_type: str
            Readout type; "concat" or "attn".
        dec_type: str
            Decoder type; must be in DEC_LIST.
        emb_dim: int
            Embedding dimension passed to every sub-module.
        num_enc_mlp_layers: int
            Number of layers passed to the MLP node encoder.
        num_dec_mlp_layers: int
            Number of layers passed to the MLP / LSTM decoder.
        num_classes: int
            Number of output classes.
        dropout: float
            Dropout rate passed to every sub-module.
        pos_encoder: str
            Positional-encoder name forwarded to the MHA decoder only.

        Raises
        ------
        AssertionError
            If node_enc_type or dec_type is not a supported value.
        NotImplementedError
            If problem or edge_enc_type is not supported.
        """
        super().__init__()
        self.problem = problem
        self.node_enc_type = node_enc_type
        self.edge_enc_type = edge_enc_type
        self.dec_type = dec_type
        assert node_enc_type in NODE_ENC_LIST, f"Invalid enc_type. select from {NODE_ENC_LIST}"
        assert dec_type in DEC_LIST, f"Invalid dec_type. select from {DEC_LIST}"
        # LSTM / MHA decoders consume a sequence of per-step embeddings.
        self.is_sequential = dec_type in ("lstm", "mha")

        coord_dim = 2  # only support 2d problem
        if problem == "tsptw":
            node_dim = 4   # coords (2) + time window (2)
            state_dim = 1  # current time (1)
        elif problem == "pctsp":
            node_dim = 4   # coords (2) + prize (1) + penalty (1)
            state_dim = 2  # current prize (1) + current penalty (1)
        elif problem == "pctsptw":
            node_dim = 6   # coords (2) + prize (1) + penalty (1) + time window (2)
            state_dim = 3  # current prize (1) + current penalty (1) + current time (1)
        elif problem == "cvrp":
            node_dim = 3   # coords (2) + demand (1)
            state_dim = 1  # remaining capacity (1)
        else:
            # Fail here instead of hitting an unbound node_dim further down.
            raise NotImplementedError(f"Unsupported problem: {problem}")

        #----------------
        # Graph encoding
        #----------------
        # Node encoder
        if node_enc_type == "mlp":
            self.node_enc = MLPNodeEncoder(coord_dim, node_dim, emb_dim, num_enc_mlp_layers, dropout)
        elif node_enc_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.node_enc = SelfMHANodeEncoder(coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout)
        else:
            raise NotImplementedError

        # Readout
        if edge_enc_type == "concat":
            self.readout = ConcatEdgeEncoder(state_dim, emb_dim, dropout)
        elif edge_enc_type == "attn":
            self.readout = AttentionEdgeEncoder(state_dim, emb_dim, dropout)
        else:
            raise NotImplementedError

        #------------------------
        # Classification Decoder
        #------------------------
        if dec_type == "mlp":
            self.decoder = MLPDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "lstm":
            self.decoder = LSTMDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.decoder = SelfMHADecoder(emb_dim, num_heads, num_mha_layers, num_classes, dropout, pos_encoder)
        else:
            raise NotImplementedError

    def forward(self, inputs):
        """Encode the graph inputs and decode them into class outputs.

        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            next_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            node_feat: torch.FloatTensor [batch_size x max_seq_length x num_nodes x node_dim] if self.is_sequential else [batch_size x num_nodes x node_dim]
            mask: torch.LongTensor [batch_size x max_seq_length x num_nodes] if self.is_sequential else [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x max_seq_length x state_dim] if self.is_sequential else [batch_size x state_dim]

        Returns
        -------
        probs: torch.tensor [batch_size x seq_length x num_classes] if self.is_sequential else [batch_size x num_classes]
            probabilities of classes
        """
        #-----------------
        # Encoding graphs
        #-----------------
        if self.is_sequential:
            # Remember (batch, seq) so the embeddings can be unflattened below,
            # then merge the first two dims of every input tensor.
            shp = inputs["curr_node_id"].size()
            inputs = {key: value.flatten(0, 1) for key, value in inputs.items()}
        # In the sequential case the leading dim is batch_size*max_seq_length.
        node_emb = self.node_enc(inputs)
        graph_emb = self.readout(inputs, node_emb)
        if self.is_sequential:
            graph_emb = graph_emb.view(*shp, -1)  # [batch_size x max_seq_length x emb_dim]

        #----------
        # Decoding
        #----------
        probs = self.decoder(graph_emb)
        return probs

    def get_inputs(self, routes, first_explained_step, node_feats):
        """Build padded, step-wise model inputs from routes and node features.

        Parameters
        ----------
        routes:
            Stored under the "tour" key of a shallow copy of node_feats before
            it is handed to the problem-specific sequential loader.
        first_explained_step:
            Not used by this method; kept for interface compatibility.
        node_feats: dict
            Instance features; not modified (a shallow copy is used).

        Returns
        -------
        dict
            Every key of the loaded sequences, post-padded to a common length,
            plus "pad_mask" (True on real steps, False on padding).

        Raises
        ------
        NotImplementedError
            If self.problem is not one of the supported problems.
        """
        node_feats_ = node_feats.copy()
        node_feats_["tour"] = routes
        if self.problem == "tsptw":
            seq_data = load_tsptw_sequentially(node_feats_)
        elif self.problem == "pctsp":
            seq_data = load_pctsp_sequentially(node_feats_)
        elif self.problem == "pctsptw":
            seq_data = load_pctsptw_sequentially(node_feats_)
        elif self.problem == "cvrp":
            seq_data = load_cvrp_sequentially(node_feats_)
        else:
            # Fail here instead of hitting an unbound seq_data further down.
            raise NotImplementedError(f"Unsupported problem: {self.problem}")

        def pad_seq_length(batch):
            """Post-pad every field of the batch to the longest sequence."""
            data = {}
            for key in batch[0].keys():
                # Padded "mask" entries are filled with True, all others with 0.
                padding_value = True if key == "mask" else 0.0
                # post-padding
                data[key] = torch.nn.utils.rnn.pad_sequence(
                    [d[key] for d in batch], batch_first=True, padding_value=padding_value
                )
            # True where a real step exists, False on padded positions.
            pad_mask = torch.nn.utils.rnn.pad_sequence(
                [torch.ones(d["mask"].size(0), dtype=torch.bool) for d in batch],
                batch_first=True,
                padding_value=False,
            )
            data.update({"pad_mask": pad_mask})
            return data

        instance = pad_seq_length(seq_data)
        return instance