"""
Problem-specific embeddings of the global context.
"""
import torch
from torch import nn
def AutoContext(problem_name, config):
"""
Automatically select the corresponding module according to ``problem_name``
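
Example (illustrative; ``config`` is unpacked into the selected class, and every
class in this module only takes ``context_dim``)::

    # e.g. for CVRP the context is [prev node embedding | remaining capacity],
    # so with 128-dimensional node embeddings context_dim would be 128 + 1
    context = AutoContext("cvrp", {"context_dim": 129})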
"""
mapping = {
"tsp": TSPContext,
"cvrp": VRPContext,
"sdvrp": VRPContext,
"pctsp": PCTSPContext,
"op": OPContext,
}
embeddingClass = mapping[problem_name]
embedding = embeddingClass(**config)
return embedding
def _gather_by_index(source, index):
"""
target[i, 0, :] = source[i, index[i], :]
Inputs:
source: [B x H x D]
index: [B x 1] (more generally [B x k], one column per node to gather)
Outputs:
target: [B x 1 x D] (or [B x k x D])
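
Example::

    source = torch.randn(4, 10, 16)               # B=4, H=10, D=16
    index = torch.zeros(4, 1, dtype=torch.long)   # pick node 0 in every batch
    target = _gather_by_index(source, index)      # -> [4, 1, 16]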
"""
target = torch.gather(source, 1, index.unsqueeze(-1).expand(-1, -1, source.size(-1)))
return target
class PrevNodeContext(nn.Module):
"""
Abstract class for Context.
Any subclass, by default, will return a concatenation of
+---------------------+-----------------+
| prev_node_embedding | state_embedding |
+---------------------+-----------------+
The ``prev_node_embedding`` is the node embedding of the last visited node.
It is obtained by the ``_prev_node_embedding`` method,
which requires ``state.get_current_node()`` to provide the index of the last visited node.
The ``state_embedding`` is the global context we want to include, such as the remaining capacity in VRP.
It is obtained by the ``_state_embedding`` method,
which is left unimplemented here; every subclass needs to implement it (see the sketch below).
Args:
problem: an object defining the settings of the environment
context_dim: the dimension of the output
Inputs: embeddings, state
* **embeddings** : [batch x graph size x embed dim]
* **state**: An object providing observations in the environment. \
Needs to supply ``state.get_current_node()``
Outputs: context_embedding
* **context_embedding**: [batch x 1 x context_dim]
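
Example (a minimal sketch of a subclass; ``state.remaining_budget`` is a
hypothetical attribute used only for illustration)::

    class BudgetContext(PrevNodeContext):
        def _state_embedding(self, embeddings, state):
            # scalar per-instance feature [batch x 1] -> [batch x 1 x 1]
            return state.remaining_budget[:, :, None]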
"""
def __init__(self, context_dim):
super(PrevNodeContext, self).__init__()
self.context_dim = context_dim
def _prev_node_embedding(self, embeddings, state):
current_node = state.get_current_node()
prev_node_embedding = _gather_by_index(embeddings, current_node)
return prev_node_embedding
def _state_embedding(self, embeddings, state):
raise NotImplementedError("Please implement the embedding for your own problem.")
def forward(self, embeddings, state):
prev_node_embedding = self._prev_node_embedding(embeddings, state)
state_embedding = self._state_embedding(embeddings, state)
# concatenate the previous node embedding with the problem-specific state embedding
context_embedding = torch.cat((prev_node_embedding, state_embedding), -1)
return context_embedding
class TSPContext(PrevNodeContext):
"""
Context node embedding for traveling salesman problem.
Return a concatenation of
+------------------------+---------------------------+
| first node's embedding | previous node's embedding |
+------------------------+---------------------------+
.. note::
Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.
In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
the input ``state`` needs to supply ``state.first_a`` for the index of the first visited node.
.. warning::
The official implementation concatenates the context with [first node, prev node].
However, if we follow the paper closely, it should instead be [prev node, first node].
Please check ``forward_code`` and ``forward_paper`` for the different implementations.
We follow the official implementation in this class.
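
Example (illustrative; assumes 128-dimensional node embeddings, so the context
holds two node embeddings side by side)::

    context = TSPContext(context_dim=2 * 128)
    # first decoding step : a learned placeholder of size 256 is returned
    # later steps         : [first node embedding | current node embedding]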
"""
def __init__(self, context_dim):
super(TSPContext, self).__init__(context_dim)
self.W_placeholder = nn.Parameter(torch.Tensor(self.context_dim).uniform_(-1, 1))
def _state_embedding(self, embeddings, state):
first_node = state.first_a
state_embedding = _gather_by_index(embeddings, first_node)
return state_embedding
def forward_paper(self, embeddings, state):
batch_size = embeddings.size(0)
if state.i.item() == 0:
context_embedding = self.W_placeholder[None, None, :].expand(
batch_size, 1, self.W_placeholder.size(-1)
)
else:
context_embedding = super().forward(embeddings, state)
return context_embedding
def forward_code(self, embeddings, state):
batch_size = embeddings.size(0)
if state.i.item() == 0:
context_embedding = self.W_placeholder[None, None, :].expand(
batch_size, 1, self.W_placeholder.size(-1)
)
else:
context_embedding = _gather_by_index(
embeddings, torch.cat([state.first_a, state.get_current_node()], -1)
).view(batch_size, 1, -1)
return context_embedding
def forward_vectorized(self, embeddings, state):
n_queries = state.states["first_node_idx"].shape[-1]
batch_size = embeddings.size(0)
out_shape = (batch_size, n_queries, self.context_dim)
switch = state.is_initial_action # bool mask, True where this is the first action
switch = switch[:, None, None].expand(out_shape) # mask for each data
# only used for the first action
placeholder_embedding = self.W_placeholder[None, None, :].expand(out_shape)
# used after first action
indexes = torch.stack([state.first_a, state.get_current_node()], -1).flatten(-2)
normal_embedding = _gather_by_index(embeddings, indexes).view(out_shape)
context_embedding = switch * placeholder_embedding + (~switch) * normal_embedding
return context_embedding
def forward(self, embeddings, state):
return self.forward_vectorized(embeddings, state)
class VRPContext(PrevNodeContext):
"""
Context node embedding for capacitated vehicle routing problem.
Return a concatenation of
+---------------------------+--------------------+
| previous node's embedding | remaining capacity |
+---------------------------+--------------------+
.. note::
Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.
In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
the input ``state`` needs to supply ``state.VEHICLE_CAPACITY`` and ``state.used_capacity``
for calculating the remaining capacity.
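
Example (illustrative; assumes 128-dimensional node embeddings)::

    context = VRPContext(context_dim=128 + 1)   # node embedding plus one capacity value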
"""
def __init__(self, context_dim):
super(VRPContext, self).__init__(context_dim)
def _state_embedding(self, embeddings, state):
state_embedding = state.VEHICLE_CAPACITY - state.used_capacity[:, :, None]
return state_embedding
class PCTSPContext(PrevNodeContext):
"""
Context node embedding for prize collecting traveling salesman problem.
Return a concatenation of
+---------------------------+----------------------------+
| previous node's embedding | remaining prize to collect |
+---------------------------+----------------------------+
.. note::
Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.
In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
the input ``state`` needs to supply ``state.get_remaining_prize_to_collect()``.
"""
def __init__(self, context_dim):
super(PCTSPContext, self).__init__(context_dim)
def _state_embedding(self, embeddings, state):
state_embedding = state.get_remaining_prize_to_collect()[:, :, None]
return state_embedding
class OPContext(PrevNodeContext):
"""
Context node embedding for orienteering problem.
Return a concatenation of
+---------------------------+---------------------------------+
| previous node's embedding | remaining tour length to travel |
+---------------------------+---------------------------------+
.. note::
Subclass of :class:`.PrevNodeContext`. The arguments, inputs, and outputs follow the same specification.
In addition to supplying ``state.get_current_node()`` for the index of the previously visited node,
the input ``state`` needs to supply ``state.get_remaining_length()``.
"""
def __init__(self, context_dim):
super(OPContext, self).__init__(context_dim)
def _state_embedding(self, embeddings, state):
state_embedding = state.get_remaining_length()[:, :, None]
return state_embedding