import torch
from torch import nn

from ...nets.attention_model.context import AutoContext
from ...nets.attention_model.dynamic_embedding import AutoDynamicEmbedding
from ...nets.attention_model.multi_head_attention import (
    AttentionScore,
    MultiHeadAttention,
)


class Decoder(nn.Module):
    r"""
    The decoder of the Attention Model.

    .. math::
        \{\log(\pmb{p}_t)\},\pi =  \mathrm{Decoder}(s, \pmb{h})

    First, the keys and values for the embedding :math:`\pmb{h}` are precomputed:

    .. math::
        \pmb{k}, \pmb{v}, \pmb{k}^\prime = W^K\pmb{h}, W^V\pmb{h}, W^{K^\prime}\pmb{h}
    and the projection of the graph embedding:

    .. math::
         W_{gc}\bar{\pmb{h}} \quad \text{ for } \bar{\pmb{h}} = \frac{1}{N}\sum\nolimits_i \pmb{h}_i.

    The decoder then unrolls the decoding process autoregressively.
    In each decoding step, several attention operations are performed to obtain the logits for each node.

    .. math::
        \begin{aligned}
        \pmb{h}_{(c)} &= [\bar{\pmb{h}}, \text{Context}(s,\pmb{h})]                                                 \\
        q & = W^Q \pmb{h}_{(c)} = W_{gc}\bar{\pmb{h}} + W_{sc}\text{Context}(s,\pmb{h})                               \\
        q_{gl} &= \mathrm{MultiHeadAttention}(q,\pmb{k},\pmb{v},\mathrm{mask}_t)                                    \\
        \pmb{p}_t &= \mathrm{Softmax}(\mathrm{AttentionScore}_{\text{clip}}(q_{gl},\pmb{k}^\prime, \mathrm{mask}_t))\\
        \pi_{t} &= \mathrm{DecodingStrategy}(\pmb{p}_t)                                                             \\
        \mathrm{mask}_{t+1} &= \mathrm{mask}_t.update(\pi_t).
        \end{aligned}



    .. note::
        If the problem has dynamic node features, specified by :mod:`.dynamic_embedding`,
        the key and value projections are updated in each decoding step by

        .. math::
            \begin{aligned}
            \pmb{k}_{\text{dynamic}}, \pmb{v}_{\text{dynamic}}, \pmb{k}^\prime_{\text{dynamic}} &= \mathrm{DynamicEmbedding}(s)\\
            \pmb{k} &= \pmb{k} + \pmb{k}_{\text{dynamic}}\\
            \pmb{v} &= \pmb{v} +\pmb{v}_{\text{dynamic}} \\
            \pmb{k}^\prime &= \pmb{k}^\prime +\pmb{k}^\prime_{\text{dynamic}}.
            \end{aligned}
    .. seealso::
        * The :math:`\text{Context}` is defined in the :mod:`.context` module.
        * The :math:`\text{AttentionScore}` is defined by the :class:`.AttentionScore` class.
        * The :math:`\text{MultiHeadAttention}` is defined by the :class:`.MultiHeadAttention` class.

    Args:
         embedding_dim : the dimension of the embedded inputs
         step_context_dim : the dimension of the context :math:`\text{Context}(s,\pmb{h})`
         n_heads: number of heads in the :math:`\mathrm{MultiHeadAttention}`
         problem: an object defining the state and the mask updating rule of the problem
         tanh_clipping : the clipping scale of the pointer (attention layer before output)
    Inputs: input, embeddings
        * **input** : dict of inputs, for example ``{'loc': tensor, 'depot': tensor, 'demand': tensor}`` for CVRP.
        * **embeddings**: [batch, graph_size, embedding_dim]
    Outputs: log_ps, pi
        * **log_ps**: [batch, T, graph_size]
        * **pi**: [batch, T]
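
    A minimal usage sketch; the hyperparameter values, the ``problem`` object,
    and the ``input``/``embeddings`` tensors below are illustrative placeholders:

    .. code-block:: python

        decoder = Decoder(
            embedding_dim=128,
            step_context_dim=256,
            n_heads=8,
            problem=problem,  # problem definition providing the state and mask update rule
            tanh_clipping=10.0,
        )
        decoder.set_decode_type("sampling")
        log_ps, pi = decoder(input, embeddings)  # embeddings: [batch, graph_size, 128]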

    """

    def __init__(self, embedding_dim, step_context_dim, n_heads, problem, tanh_clipping):
        super(Decoder, self).__init__()
        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim

        self.project_node_embeddings = nn.Linear(embedding_dim, 3 * embedding_dim, bias=False)
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.project_step_context = nn.Linear(step_context_dim, embedding_dim, bias=False)

        self.context = AutoContext(problem.NAME, {"context_dim": step_context_dim})
        self.dynamic_embedding = AutoDynamicEmbedding(
            problem.NAME, {"embedding_dim": embedding_dim}
        )
        self.glimpse = MultiHeadAttention(embedding_dim=embedding_dim, n_heads=n_heads)
        self.pointer = AttentionScore(use_tanh=True, C=tanh_clipping)

        self.decode_type = None
        self.problem = problem

    def forward(self, input, embeddings):
        outputs = []
        sequences = []

        state = self.problem.make_state(input)

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeddings = self._precompute(embeddings)

        # Perform decoding steps
        while not state.all_finished():

            log_p, mask = self.advance(cached_embeddings, state)

            # Select the index of the next node in each sequence; result is [batch_size]
            action = self.decode(log_p.exp(), mask)
            state = state.update(action)

            # Collect output of step
            outputs.append(log_p)
            sequences.append(action)

        # Stack the per-step log-probabilities and actions into tensors
        return torch.stack(outputs, 1), torch.stack(sequences, 1)

    def set_decode_type(self, decode_type):
        r"""
        Currently supported decode types:

        .. code-block:: python

            ["greedy", "sampling"]

        """
        self.decode_type = decode_type

    def decode(self, probs, mask):
        r"""
        Execute the decoding strategy specified by ``self.decode_type``.

        Inputs:
            * **probs**: [batch_size, graph_size]
            * **mask** (bool): [batch_size, graph_size]
        Outputs:
            * **idxs** (int): index of action chosen. [batch_size]
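
        A minimal sketch of the greedy branch with illustrative tensors
        (``True`` in the mask marks an infeasible node):

        .. code-block:: python

            probs = torch.tensor([[0.1, 0.7, 0.2]])      # [batch=1, graph_size=3]
            mask = torch.tensor([[False, False, True]])  # node 2 is infeasible
            _, idx = probs.max(1)                        # idx == tensor([1])
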
        """
        assert not torch.isnan(probs).any(), "Probs should not contain any NaNs"

        if self.decode_type == "greedy":
            _, selected = probs.max(1)
            assert not mask.gather(
                1, selected.unsqueeze(-1)
            ).data.any(), "Decode greedy: infeasible action has maximum probability"

        elif self.decode_type == "sampling":
            selected = probs.multinomial(1).squeeze(1)

            # Check that sampling did not select a masked node; this can fail due to a bug on the GPU
            # See https://discuss.pytorch.org/t/bad-behavior-of-multinomial-function/10232
            while mask.gather(1, selected.unsqueeze(-1)).data.any():
                print("Sampled bad values, resampling!")
                selected = probs.multinomial(1).squeeze(1)

        else:
            raise ValueError(f"Unknown decode type: {self.decode_type}")
        return selected

    def _precompute(self, embeddings):

        # The fixed context projection of the graph embedding is calculated only once for efficiency
        graph_embed = embeddings.mean(1)
        # fixed context = (batch_size, 1, embed_dim) to make broadcastable with parallel timesteps
        graph_context = self.project_fixed_context(graph_embed).unsqueeze(-2)
        # The projection of the node embeddings for the attention is calculated once up front
        glimpse_key, glimpse_val, logit_key = self.project_node_embeddings(embeddings).chunk(
            3, dim=-1
        )
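        # glimpse_key, glimpse_val, logit_key: each [batch, graph_size, embedding_dim]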

        cache = (
            embeddings,
            graph_context,
            glimpse_key,
            glimpse_val,
            logit_key,
        )  # single head for the final logit
        return cache

    def advance(self, cached_embeddings, state):

        node_embeddings, graph_context, glimpse_K, glimpse_V, logit_K = cached_embeddings

        # Compute context node embedding: [graph embedding| prev node| problem-state-context]
        # [batch, 1, context dim]
        context = self.context(node_embeddings, state)
        step_context = self.project_step_context(context)  # [batch, 1, embed_dim]
        query = graph_context + step_context  # [batch, 1, embed_dim]

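        # Update the cached key/value projections with state-dependent (dynamic)
        # embeddings, as described in the note of the class docstring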
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(state)
        glimpse_K = glimpse_K + glimpse_key_dynamic
        glimpse_V = glimpse_V + glimpse_val_dynamic
        logit_K = logit_K + logit_key_dynamic
        # Compute the mask
        mask = state.get_mask()

        # Compute the logits (masked, clipped attention scores) for each node
        logits, glimpse = self.calc_logits(query, glimpse_K, glimpse_V, logit_K, mask)

        # Normalize the logits into log-probabilities, as expected by forward()
        log_p = torch.log_softmax(logits, dim=-1)

        return log_p, mask

    def calc_logits(self, query, glimpse_K, glimpse_V, logit_K, mask):
        # Compute glimpse with multi-head-attention.
        # Then use glimpse as a query to compute logits for each node

        # [batch, 1, embed dim]
        glimpse = self.glimpse(query, glimpse_K, glimpse_V, mask)

        logits = self.pointer(glimpse, logit_K, mask)

        return logits, glimpse