import torch
from torch import nn

from ...nets.attention_model.context import AutoContext
from ...nets.attention_model.dynamic_embedding import AutoDynamicEmbedding
from ...nets.attention_model.multi_head_attention import (
    AttentionScore,
    MultiHeadAttention,
)


class Decoder(nn.Module):
    r"""
    The decoder of the Attention Model.

    .. math::
        \{\log(\pmb{p}_t)\}, \pi = \mathrm{Decoder}(s, \pmb{h})

    First of all, precompute the keys and values for the embedding :math:`\pmb{h}`:

    .. math::
        \pmb{k}, \pmb{v}, \pmb{k}^\prime = W^K\pmb{h}, W^V\pmb{h}, W^{K^\prime}\pmb{h}

    and the projection of the graph embedding:

    .. math::
        W_{gc}\bar{\pmb{h}} \quad \text{ for } \bar{\pmb{h}} = \frac{1}{N}\sum\nolimits_i \pmb{h}_i.

    Then, the decoder iterates the decoding process autoregressively.
    In each decoding step, we perform multiple attentions to get the logits for each node.

    .. math::
        \begin{aligned}
            \pmb{h}_{(c)} &= [\bar{\pmb{h}}, \text{Context}(s, \pmb{h})] \\
            q &= W^Q \pmb{h}_{(c)} = W_{gc}\bar{\pmb{h}} + W_{sc}\text{Context}(s, \pmb{h}) \\
            q_{gl} &= \mathrm{MultiHeadAttention}(q, \pmb{k}, \pmb{v}, \mathrm{mask}_t) \\
            \pmb{p}_t &= \mathrm{Softmax}(\mathrm{AttentionScore}_{\text{clip}}(q_{gl}, \pmb{k}^\prime, \mathrm{mask}_t)) \\
            \pi_{t} &= \mathrm{DecodingStrategy}(\pmb{p}_t) \\
            \mathrm{mask}_{t+1} &= \mathrm{mask}_t.update(\pi_t).
        \end{aligned}

    .. note::
        If there are dynamic node features specified by :mod:`.dynamic_embedding`,
        the key and value projections are updated in each decoding step by

        .. math::
            \begin{aligned}
                \pmb{k}_{\text{dynamic}}, \pmb{v}_{\text{dynamic}}, \pmb{k}^\prime_{\text{dynamic}} &= \mathrm{DynamicEmbedding}(s) \\
                \pmb{k} &= \pmb{k} + \pmb{k}_{\text{dynamic}} \\
                \pmb{v} &= \pmb{v} + \pmb{v}_{\text{dynamic}} \\
                \pmb{k}^\prime &= \pmb{k}^\prime + \pmb{k}^\prime_{\text{dynamic}}.
            \end{aligned}

    .. seealso::
        * The :math:`\text{Context}` is defined in the :mod:`.context` module.
        * The :math:`\text{AttentionScore}` is defined by the :class:`.AttentionScore` class.
        * The :math:`\text{MultiHeadAttention}` is defined by the :class:`.MultiHeadAttention` class.

    Args:
        embedding_dim: the dimension of the embedded inputs
        step_context_dim: the dimension of the context :math:`\text{Context}(\pmb{x})`
        n_heads: number of heads in the :math:`\mathrm{MultiHeadAttention}`
        problem: an object defining the state and the mask-updating rule of the problem
        tanh_clipping: the clipping scale of the pointer (attention layer before output)

    Inputs: input, embeddings
        * **input**: dict of inputs, for example
          ``{'loc': tensor, 'depot': tensor, 'demand': tensor}`` for CVRP.
        * **embeddings**: [batch, graph_size, embedding_dim]

    Outputs: log_ps, pi
        * **log_ps**: [batch, graph_size, T]
        * **pi**: [batch, T]
    """

    def __init__(self, embedding_dim, step_context_dim, n_heads, problem, tanh_clipping):
        super(Decoder, self).__init__()

        # For each node we compute (glimpse key, glimpse value, logit key), hence 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.project_step_context = nn.Linear(step_context_dim, embedding_dim, bias=False)

        self.context = AutoContext(problem.NAME, {"context_dim": step_context_dim})
        self.dynamic_embedding = AutoDynamicEmbedding(
            problem.NAME, {"embedding_dim": embedding_dim}
        )
        self.glimpse = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
        self.pointer = AttentionScore(use_tanh=True, C=tanh_clipping)

        self.decode_type = None
        self.problem = problem

    def forward(self, input, embeddings):
        outputs = []
        sequences = []

        state = self.problem.make_state(input)

        # Compute keys and values for the glimpse and keys for the logits once,
        # as they can be reused in every step
        cached_embeddings = self._precompute(embeddings)

        # Perform decoding steps
        while not state.all_finished():
            log_p, mask = self.advance(cached_embeddings, state)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            # Squeeze out the steps dimension
            action = self.decode(log_p.exp(), mask)
            state = state.update(action)

            # Collect output of step
            outputs.append(log_p)
            sequences.append(action)

        # Collected lists, return Tensor
        return torch.stack(outputs, 1), torch.stack(sequences, 1)

    def set_decode_type(self, decode_type):
        r"""
        Currently supported:

        .. code-block:: python

            ["greedy", "sampling"]
        """
        self.decode_type = decode_type

    def decode(self, probs, mask):
        r"""
        Execute the decoding strategy specified by ``self.decode_type``.

        Inputs:
            * **probs**: [batch_size, graph_size]
            * **mask** (bool): [batch_size, graph_size]

        Outputs:
            * **idxs** (int): index of the action chosen.
              [batch_size]
        """
        assert (probs == probs).all(), "Probs should not contain any nans"

        if self.decode_type == "greedy":
            _, selected = probs.max(1)
            assert not mask.gather(
                1, selected.unsqueeze(-1)
            ).data.any(), "Decode greedy: infeasible action has maximum probability"
        elif self.decode_type == "sampling":
            selected = probs.multinomial(1).squeeze(1)

            # Check if sampling went OK, can go wrong due to bug on GPU
            # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
            while mask.gather(1, selected.unsqueeze(-1)).data.any():
                print("Sampled bad values, resampling!")
                selected = probs.multinomial(1).squeeze(1)
        else:
            assert False, "Unknown decode type"

        return selected

    def _precompute(self, embeddings):
        # The fixed context projection of the graph embedding is calculated only once for efficiency
        graph_embed = embeddings.mean(1)

        # fixed context = (batch_size, 1, embed_dim) to make broadcastable with parallel timesteps
        graph_context = self.project_fixed_context(graph_embed).unsqueeze(-2)

        # The projection of the node embeddings for the attention is calculated once up front
        glimpse_key, glimpse_val, logit_key = self.project_node_embeddings(embeddings).chunk(
            3, dim=-1
        )

        cache = (
            embeddings,
            graph_context,
            glimpse_key,
            glimpse_val,
            logit_key,
        )  # single head for the final logit
        return cache

    def advance(self, cached_embeddings, state):
        node_embeddings, graph_context, glimpse_K, glimpse_V, logit_K = cached_embeddings

        # Compute context node embedding: [graph embedding | prev node | problem-state-context]
        # [batch, 1, context_dim]
        context = self.context(node_embeddings, state)
        step_context = self.project_step_context(context)  # [batch, 1, embed_dim]
        query = graph_context + step_context  # [batch, 1, embed_dim]

        # Add dynamic updates to the cached key/value projections (see the class docstring note)
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(state)
        glimpse_K = glimpse_K + glimpse_key_dynamic
        glimpse_V = glimpse_V + glimpse_val_dynamic
        logit_K = logit_K + logit_key_dynamic

        # Compute the mask
        mask = state.get_mask()

        # Compute logits (unnormalized log_p)
        logits, glimpse = self.calc_logits(query, glimpse_K, glimpse_V, logit_K, mask)

        return logits, mask

    def calc_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
        # Compute the glimpse with multi-head attention,
        # then use the glimpse as a query to compute logits for each node
        glimpse = self.glimpse(query, glimpse_K, glimpse_V, mask)  # [batch, 1, embed_dim]
        logits = self.pointer(glimpse, logit_K, mask)

        return logits, glimpse
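
# Usage sketch (illustrative only, not part of this module's API): a rough idea of
# how the decoder above is driven. The concrete `problem` object (providing `NAME`,
# `make_state`, and the mask-updating state used in `forward`), the encoder producing
# `embeddings`, and all dimensions below are assumptions for the sake of example,
# not definitions from this codebase.
#
#   problem = ...                        # a problem class exposing NAME / make_state
#   encoder = ...                        # any encoder producing node embeddings
#
#   decoder = Decoder(
#       embedding_dim=128,
#       step_context_dim=2 * 128,        # must match the problem's Context definition
#       n_heads=8,
#       problem=problem,
#       tanh_clipping=10.0,
#   )
#   decoder.set_decode_type("sampling")  # or "greedy"
#
#   embeddings = encoder(input)          # [batch, graph_size, embedding_dim]
#   log_ps, pi = decoder(input, embeddings)
#   # log_ps: per-step log-probabilities, pi: decoded node sequence (see docstring for shapes)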