""" Problem specific global embedding for global context. """ import torch from torch import nn def AutoContext(problem_name, config): """ Automatically select the corresponding module according to ``problem_name`` """ mapping = { "tsp": TSPContext, "cvrp": VRPContext, "sdvrp": VRPContext, "pctsp": PCTSPContext, "op": OPContext, } embeddingClass = mapping[problem_name] embedding = embeddingClass(**config) return embedding def _gather_by_index(source, index): """ target[i,1,:] = source[i,index[i],:] Inputs: source: [B x H x D] index: [B x 1] or [B] Outpus: target: [B x 1 x D] """ target = torch.gather(source, 1, index.unsqueeze(-1).expand(-1, -1, source.size(-1))) return target class PrevNodeContext(nn.Module): """ Abstract class for Context. Any subclass, by default, will return a concatenation of +---------------------+-----------------+ | prev_node_embedding | state_embedding | +---------------------+-----------------+ The ``prev_node_embedding`` is the node embedding of the last visited node. It is obtained by ``_prev_node_embedding`` method. It requires ``state.get_current_node()`` to provide the index of the last visited node. The ``state_embedding`` is the global context we want to include, such as the remaining capacity in VRP. It is obtained by ``_state_embedding`` method. It is not implemented. The subclass of this abstract class needs to implement this method. Args: problem: an object defining the settings of the environment context_dim: the dimension of the output Inputs: embeddings, state * **embeddings** : [batch x graph size x embed dim] * **state**: An object providing observations in the environment. \ Needs to supply ``state.get_current_node()`` Outputs: context_embedding * **context_embedding**: [batch x 1 x context_dim] """ def __init__(self, context_dim): super(PrevNodeContext, self).__init__() self.context_dim = context_dim def _prev_node_embedding(self, embeddings, state): current_node = state.get_current_node() prev_node_embedding = _gather_by_index(embeddings, current_node) return prev_node_embedding def _state_embedding(self, embeddings, state): raise NotImplementedError("Please implement the embedding for your own problem.") def forward(self, embeddings, state): prev_node_embedding = self._prev_node_embedding(embeddings, state) state_embedding = self._state_embedding(embeddings, state) # Embedding of previous node + remaining capacity context_embedding = torch.cat((prev_node_embedding, state_embedding), -1) return context_embedding class TSPContext(PrevNodeContext): """ Context node embedding for traveling salesman problem. Return a concatenation of +------------------------+---------------------------+ | first node's embedding | previous node's embedding | +------------------------+---------------------------+ .. note:: Subclass of :class:`.PrevNodeContext`. The argument, inputs, outputs follow the same specification. In addition to supplying ``state.get_current_node()`` for the index of the previous visited node. The input ``state`` needs to supply ``state.first_a`` for the index of the first visited node. .. warning:: The official implementation concatenates the context with [first node, prev node]. However, if we follow the paper closely, it should instead be [prev node, first node]. Please check ``forward_code`` and ``forward_paper`` for the different implementations. We follow the official implementation in this class. """ def __init__(self, context_dim): super(TSPContext, self).__init__(context_dim) self.W_placeholder = nn.Parameter(torch.Tensor(self.context_dim).uniform_(-1, 1)) def _state_embedding(self, embeddings, state): first_node = state.first_a state_embedding = _gather_by_index(embeddings, first_node) return state_embedding def forward_paper(self, embeddings, state): batch_size = embeddings.size(0) if state.i.item() == 0: context_embedding = self.W_placeholder[None, None, :].expand( batch_size, 1, self.W_placeholder.size(-1) ) else: context_embedding = super().forward(embeddings, state) return context_embedding def forward_code(self, embeddings, state): batch_size = embeddings.size(0) if state.i.item() == 0: context_embedding = self.W_placeholder[None, None, :].expand( batch_size, 1, self.W_placeholder.size(-1) ) else: context_embedding = _gather_by_index( embeddings, torch.cat([state.first_a, state.get_current_node()], -1) ).view(batch_size, 1, -1) return context_embedding def forward_vectorized(self, embeddings, state): n_queries = state.states["first_node_idx"].shape[-1] batch_size = embeddings.size(0) out_shape = (batch_size, n_queries, self.context_dim) switch = state.is_initial_action # tensor, 1 if is initial action switch = switch[:, None, None].expand(out_shape) # mask for each data # only used for the first action placeholder_embedding = self.W_placeholder[None, None, :].expand(out_shape) # used after first action indexes = torch.stack([state.first_a, state.get_current_node()], -1).flatten(-2) normal_embedding = _gather_by_index(embeddings, indexes).view(out_shape) context_embedding = switch * placeholder_embedding + (~switch) * normal_embedding return context_embedding def forward(self, embeddings, state): return self.forward_vectorized(embeddings, state) class VRPContext(PrevNodeContext): """ Context node embedding for capacitated vehicle routing problem. Return a concatenation of +---------------------------+--------------------+ | previous node's embedding | remaining capacity | +---------------------------+--------------------+ .. note:: Subclass of :class:`.PrevNodeContext`. The argument, inputs, outputs follow the same specification. In addition to supplying ``state.get_current_node()`` for the index of the previous visited node. The input ``state`` needs to supply ``state.VEHICLE_CAPACITY`` and ``state.used_capacity`` for calculating the remaining capcacity. """ def __init__(self, context_dim): super(VRPContext, self).__init__(context_dim) def _state_embedding(self, embeddings, state): state_embedding = state.VEHICLE_CAPACITY - state.used_capacity[:, :, None] return state_embedding class PCTSPContext(PrevNodeContext): """ Context node embedding for prize collecting traveling salesman problem. Return a concatenation of +---------------------------+----------------------------+ | previous node's embedding | remaining prize to collect | +---------------------------+----------------------------+ .. note:: Subclass of :class:`.PrevNodeContext`. The argument, inputs, outputs follow the same specification. In addition to supplying ``state.get_current_node()`` for the index of the previous visited node. The input ``state`` needs to supply ``state.get_remaining_prize_to_collect()``. """ def __init__(self, context_dim): super(PCTSPContext, self).__init__(context_dim) def _state_embedding(self, embeddings, state): state_embedding = state.get_remaining_prize_to_collect()[:, :, None] return state_embedding class OPContext(PrevNodeContext): """ Context node embedding for orienteering problem. Return a concatenation of +---------------------------+---------------------------------+ | previous node's embedding | remaining tour length to travel | +---------------------------+---------------------------------+ .. note:: Subclass of :class:`.PrevNodeContext`. The argument, inputs, outputs follow the same specification. In addition to supplying ``state.get_current_node()`` for the index of the previous visited node. The input ``state`` needs to supply ``state.get_remaining_length()``. """ def __init__(self, context_dim): super(OPContext, self).__init__(context_dim) def _state_embedding(self, embeddings, state): state_embedding = state.get_remaining_length()[:, :, None] return state_embedding