from torch import nn

from ...nets.attention_model.decoder import Decoder
from ...nets.attention_model.embedding import AutoEmbedding
from ...nets.attention_model.encoder import GraphAttentionEncoder


class AttentionModel(nn.Module):
    r"""
    The Attention Model from Kool et al., 2019.
    `Link to paper <https://arxiv.org/abs/1803.08475>`_.

    For an instance :math:`s`,

    .. math::
        \log(p_\theta(\pi|s)), \pi = \mathrm{AttentionModel}(s)

    The following is executed:

    .. math::
        \begin{aligned}
        \pmb{x} &= \mathrm{Embedding}(s) \\
        \pmb{h} &= \mathrm{Encoder}(\pmb{x}) \\
        \{\log(\pmb{p}_t)\}, \pi &= \mathrm{Decoder}(s, \pmb{h}) \\
        \log(p_\theta(\pi|s)) &= \sum\nolimits_t \log(\pmb{p}_{t,\pi_t})
        \end{aligned}

    where :math:`\pmb{h}_i` is the node embedding for each node :math:`i` in the graph.

    In a nutshell, :math:`\mathrm{Embedding}` is a linear projection.
    The :math:`\mathrm{Encoder}` is a transformer.
    The :math:`\mathrm{Decoder}` uses (multi-head) attention.
    The policy (sequence of actions) :math:`\pi` is decoded autoregressively.
    The log-likelihood :math:`\log(p_\theta(\pi|s))` of the policy is also returned.

    .. seealso::
        The definition of :math:`\mathrm{Embedding}`, :math:`\mathrm{Encoder}`, and
        :math:`\mathrm{Decoder}` can be found in the
        :mod:`.embedding`, :mod:`.encoder`, :mod:`.decoder` modules.

    Args:
        embedding_dim : the dimension of the embedded inputs
        hidden_dim : the dimension of the hidden state of the encoder
        problem : an object defining the state of the problem
        n_encode_layers : the number of encoder layers
        tanh_clipping : the clipping scale for the decoder
        normalization : the normalization scheme (default: ``"batch"``)
        n_heads : the number of attention heads (default: 8)
    Inputs: inputs
        * **inputs**: problem instance :math:`s`. [batch, graph_size, input_dim]
    Outputs: ll, pi
        * **ll**: :math:`\log(p_\theta(\pi|s))` the log-likelihood of the policy. [batch]
        * **pi**: the policy generated :math:`\pi`. [batch, T]
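
    Example:
        A minimal usage sketch; the concrete ``problem`` object (a problem
        definition exposing ``NAME``) and the ``inputs`` tensor are assumed to
        be provided by the surrounding code::

            model = AttentionModel(embedding_dim=128, hidden_dim=128, problem=problem)
            model.set_decode_type("greedy")  # any decode type understood by the Decoder
            ll, pi = model(inputs)           # inputs: [batch, graph_size, input_dim]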
    """

    def __init__(
        self,
        embedding_dim,
        hidden_dim,
        problem,
        n_encode_layers=2,
        tanh_clipping=10.0,
        normalization="batch",
        n_heads=8,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.n_encode_layers = n_encode_layers
        self.decode_type = None
        self.n_heads = n_heads

        self.problem = problem

        # Problem-specific input embedding (a linear projection of the raw inputs, selected by problem name)
        self.embedding = AutoEmbedding(problem.NAME, {"embedding_dim": embedding_dim})
        step_context_dim = self.embedding.context_dim

        # Transformer-style graph attention encoder over the embedded nodes
        self.encoder = GraphAttentionEncoder(
            n_heads=n_heads,
            embed_dim=embedding_dim,
            n_layers=self.n_encode_layers,
        )
        # Autoregressive attention decoder producing per-step log-probabilities and the selected actions
        self.decoder = Decoder(embedding_dim, step_context_dim, n_heads, problem, tanh_clipping)

    def set_decode_type(self, decode_type):
        self.decoder.set_decode_type(decode_type)

    def forward(self, input):
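        # Embed the raw instance: [batch, graph_size, input_dim] -> [batch, graph_size, embedding_dim]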
        embedding = self.embedding(input)
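        # Encode the node embeddings with the graph attention encoder (the second return value is unused here)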
        encoded_inputs, _ = self.encoder(embedding)
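        # Autoregressively decode the tour: per-step log-probabilities and the selected actions pi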
        _log_p, pi = self.decoder(input, encoded_inputs)
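        # Sum the log-probabilities of the selected actions into the policy log-likelihood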
        ll = self._calc_log_likelihood(_log_p, pi)
        return ll, pi

    def _calc_log_likelihood(self, _log_p, pi):

        # Get log_p corresponding to selected actions
        log_p = _log_p.gather(2, pi.unsqueeze(-1)).squeeze(-1)

        assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"

        # Calculate log_likelihood
        return log_p.sum(1)