import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Node encoders
from models.classifiers.nn_classifiers.encoders.mlp_node_encoder import MLPNodeEncoder
from models.classifiers.nn_classifiers.encoders.mha_node_encoder import SelfMHANodeEncoder
# Edge encoders
from models.classifiers.nn_classifiers.encoders.attn_edge_encoder import AttentionEdgeEncoder
from models.classifiers.nn_classifiers.encoders.concat_edge_encoder import ConcatEdgeEncoder
# Decoders
from models.classifiers.nn_classifiers.decoders.lstm_decoder import LSTMDecoder
from models.classifiers.nn_classifiers.decoders.mlp_decoder import MLPDecoder
from models.classifiers.nn_classifiers.decoders.mha_decoder import SelfMHADecoder
# Data loaders
from utils.data_utils.tsptw_dataset import load_tsptw_sequentially
from utils.data_utils.pctsp_dataset import load_pctsp_sequentially
from utils.data_utils.pctsptw_dataset import load_pctsptw_sequentially
from utils.data_utils.cvrp_dataset import load_cvrp_sequentially

NODE_ENC_LIST = ["mlp", "mha"]
EDGE_ENC_LIST = ["concat", "attn"]
DEC_LIST = ["mlp", "lstm", "mha"]

class NNClassifier(nn.Module):
    def __init__(self,
                 problem: str,
                 node_enc_type: str,
                 edge_enc_type: str,
                 dec_type: str,
                 emb_dim: int,
                 num_enc_mlp_layers: int,
                 num_dec_mlp_layers: int,
                 num_classes: int,
                 dropout: float,
                 pos_encoder: str = "sincos"):
        super().__init__()
        self.problem = problem
        self.node_enc_type = node_enc_type
        self.edge_enc_type = edge_enc_type
        self.dec_type = dec_type
        assert node_enc_type in NODE_ENC_LIST, f"Invalid node_enc_type. Select from {NODE_ENC_LIST}"
        assert edge_enc_type in EDGE_ENC_LIST, f"Invalid edge_enc_type. Select from {EDGE_ENC_LIST}"
        assert dec_type in DEC_LIST, f"Invalid dec_type. Select from {DEC_LIST}"
        self.is_sequential = dec_type in ["lstm", "mha"]
        coord_dim = 2  # only 2D problems are supported
        if problem == "tsptw":
            node_dim = 4   # coords (2) + time window (2)
            state_dim = 1  # current time (1)
        elif problem == "pctsp":
            node_dim = 4   # coords (2) + prize (1) + penalty (1)
            state_dim = 2  # current prize (1) + current penalty (1)
        elif problem == "pctsptw":
            node_dim = 6   # coords (2) + prize (1) + penalty (1) + time window (2)
            state_dim = 3  # current prize (1) + current penalty (1) + current time (1)
        elif problem == "cvrp":
            node_dim = 3   # coords (2) + demand (1)
            state_dim = 1  # remaining capacity (1)
        else:
            raise NotImplementedError
        #----------------
        # Graph encoding
        #----------------
        # Node encoder
        if node_enc_type == "mlp":
            self.node_enc = MLPNodeEncoder(coord_dim, node_dim, emb_dim, num_enc_mlp_layers, dropout)
        elif node_enc_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.node_enc = SelfMHANodeEncoder(coord_dim, node_dim, emb_dim, num_heads, num_mha_layers, dropout)
        else:
            raise NotImplementedError

        # Readout
        if edge_enc_type == "concat":
            self.readout = ConcatEdgeEncoder(state_dim, emb_dim, dropout)
        elif edge_enc_type == "attn":
            self.readout = AttentionEdgeEncoder(state_dim, emb_dim, dropout)
        else:
            raise NotImplementedError

        #------------------------
        # Classification Decoder
        #------------------------
        if dec_type == "mlp":
            self.decoder = MLPDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "lstm":
            self.decoder = LSTMDecoder(emb_dim, num_dec_mlp_layers, num_classes, dropout)
        elif dec_type == "mha":
            num_heads = 8
            num_mha_layers = 2
            self.decoder = SelfMHADecoder(emb_dim, num_heads, num_mha_layers, num_classes, dropout, pos_encoder)
        else:
            raise NotImplementedError
    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: dict
            curr_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            next_node_id: torch.LongTensor [batch_size x max_seq_length] if self.is_sequential else [batch_size]
            node_feat: torch.FloatTensor [batch_size x max_seq_length x num_nodes x node_dim] if self.is_sequential else [batch_size x num_nodes x node_dim]
            mask: torch.LongTensor [batch_size x max_seq_length x num_nodes] if self.is_sequential else [batch_size x num_nodes]
            state: torch.FloatTensor [batch_size x max_seq_length x state_dim] if self.is_sequential else [batch_size x state_dim]

        Returns
        -------
        probs: torch.FloatTensor [batch_size x max_seq_length x num_classes] if self.is_sequential else [batch_size x num_classes]
            class probabilities
        """
        #-----------------
        # Encoding graphs
        #-----------------
        if self.is_sequential:
            shp = inputs["curr_node_id"].size()
            # flatten batch and sequence dims so the encoders see a single flat batch
            inputs = {key: value.flatten(0, 1) for key, value in inputs.items()}
        node_emb = self.node_enc(inputs)  # [(batch_size*max_seq_length) x emb_dim] if self.is_sequential else [batch_size x emb_dim]
        graph_emb = self.readout(inputs, node_emb)
        if self.is_sequential:
            graph_emb = graph_emb.view(*shp, -1)  # [batch_size x max_seq_length x emb_dim]

        #----------
        # Decoding
        #----------
        probs = self.decoder(graph_emb)
        return probs
    def get_inputs(self, routes, first_explained_step, node_feats):
        """Build padded, sequential model inputs from routes and node features."""
        node_feats_ = node_feats.copy()
        node_feats_["tour"] = routes
        if self.problem == "tsptw":
            seq_data = load_tsptw_sequentially(node_feats_)
        elif self.problem == "pctsp":
            seq_data = load_pctsp_sequentially(node_feats_)
        elif self.problem == "pctsptw":
            seq_data = load_pctsptw_sequentially(node_feats_)
        elif self.problem == "cvrp":
            seq_data = load_cvrp_sequentially(node_feats_)
        else:
            raise NotImplementedError

        def pad_seq_length(batch):
            data = {}
            for key in batch[0].keys():
                padding_value = True if key == "mask" else 0.0
                # post-padding
                data[key] = torch.nn.utils.rnn.pad_sequence([d[key] for d in batch], batch_first=True, padding_value=padding_value)
            # pad_mask marks the valid (non-padded) steps of each sequence
            pad_mask = torch.nn.utils.rnn.pad_sequence([torch.full((d["mask"].size(0), ), True) for d in batch], batch_first=True, padding_value=False)
            data.update({"pad_mask": pad_mask})
            return data

        instance = pad_seq_length(seq_data)
        return instance
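

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it shows how the
# non-sequential path ("mlp" decoder) of NNClassifier.forward could be driven
# with dummy tensors. The dummy shapes follow the forward() docstring above;
# the exact keys and layouts expected by the node encoder and readout modules
# are assumptions, as are the hyperparameter values chosen here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, num_nodes = 4, 20
    model = NNClassifier(
        problem="tsptw",        # node_dim = 4, state_dim = 1
        node_enc_type="mlp",
        edge_enc_type="concat",
        dec_type="mlp",         # non-sequential decoder
        emb_dim=128,
        num_enc_mlp_layers=2,
        num_dec_mlp_layers=2,
        num_classes=2,
        dropout=0.1,
    )
    dummy_inputs = {
        "curr_node_id": torch.randint(0, num_nodes, (batch_size,)),
        "next_node_id": torch.randint(0, num_nodes, (batch_size,)),
        "node_feat": torch.rand(batch_size, num_nodes, 4),            # coords + time window
        "mask": torch.ones(batch_size, num_nodes, dtype=torch.long),
        "state": torch.rand(batch_size, 1),                           # current time
    }
    probs = model(dummy_inputs)  # expected shape: [batch_size x num_classes]
    print(probs.shape)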