from torch import nn
from ...nets.attention_model.multi_head_attention import MultiHeadAttentionProj
class SkipConnection(nn.Module):
    def __init__(self, module):
        super(SkipConnection, self).__init__()
        self.module = module

    def forward(self, input):
        return input + self.module(input)
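
# Usage sketch for SkipConnection (illustrative only; the wrapped nn.Linear is an
# arbitrary stand-in for any shape-preserving module):
#
#     import torch
#     block = SkipConnection(nn.Linear(128, 128))
#     x = torch.rand(2, 20, 128)   # [batch, graph_size, embedding_dim]
#     y = block(x)                 # residual output x + module(x), same shape as x
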
class Normalization(nn.Module):
    def __init__(self, embedding_dim):
        super(Normalization, self).__init__()
        self.normalizer = nn.BatchNorm1d(embedding_dim, affine=True)

    def forward(self, input):
        # Alternative: treat embedding_dim as the channel dim via permute.
        # Numerically it differs from the view-based version below by ~3e-6.
        # out = self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1)
        # return out
        return self.normalizer(input.view(-1, input.size(-1))).view(input.size())
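
# Shape sketch for Normalization (illustrative only): nn.BatchNorm1d expects
# [N, C] input, so forward() flattens the batch and graph dimensions together.
#
#     import torch
#     norm = Normalization(embedding_dim=128)
#     x = torch.rand(2, 20, 128)          # [batch, graph_size, embedding_dim]
#     flat = x.view(-1, x.size(-1))       # [2 * 20, 128] is what BatchNorm1d sees
#     out = norm(x)                       # equals norm.normalizer(flat).view(x.size())
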
class MultiHeadAttentionLayer(nn.Sequential):
r"""
A layer with attention mechanism and normalization.
For an embedding :math:`\pmb{x}`,
.. math::
\pmb{h} = \mathrm{MultiHeadAttentionLayer}(\pmb{x})
The following is executed:
.. math::
\begin{aligned}
\pmb{x}_0&=\pmb{x}+\mathrm{MultiHeadAttentionProj}(\pmb{x}) \\
\pmb{x}_1&=\mathrm{BatchNorm}(\pmb{x}_0) \\
\pmb{x}_2&=\pmb{x}_1+\mathrm{MLP_{\text{2 layers}}}(\pmb{x}_1)\\
\pmb{h} &=\mathrm{BatchNorm}(\pmb{x}_2)
\end{aligned}
.. seealso::
The :math:`\mathrm{MultiHeadAttentionProj}` computes the self attention
of the embedding :math:`\pmb{x}`. Check :class:`~.MultiHeadAttentionProj` for details.
Args:
n_heads : number of heads
embedding_dim : dimension of the query, keys, values
feed_forward_hidden : size of the hidden layer in the MLP
Inputs: inputs
* **inputs**: embeddin :math:`\pmb{x}`. [batch, graph_size, embedding_dim]
Outputs: out
* **out**: the output :math:`\pmb{h}` [batch, graph_size, embedding_dim]
"""
    def __init__(
        self,
        n_heads,
        embedding_dim,
        feed_forward_hidden=512,
    ):
        super(MultiHeadAttentionLayer, self).__init__(
            SkipConnection(
                MultiHeadAttentionProj(
                    embedding_dim=embedding_dim,
                    n_heads=n_heads,
                )
            ),
            Normalization(embedding_dim),
            SkipConnection(
                nn.Sequential(
                    nn.Linear(embedding_dim, feed_forward_hidden),
                    nn.ReLU(),
                    nn.Linear(feed_forward_hidden, embedding_dim),
                )
                if feed_forward_hidden > 0
                else nn.Linear(embedding_dim, embedding_dim)
            ),
            Normalization(embedding_dim),
        )
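
# Usage sketch for MultiHeadAttentionLayer (illustrative only; the sizes are
# arbitrary example values):
#
#     import torch
#     layer = MultiHeadAttentionLayer(n_heads=8, embedding_dim=128)
#     layer.eval()                 # avoid updating BatchNorm stats for a single pass
#     x = torch.rand(2, 20, 128)   # [batch, graph_size, embedding_dim]
#     h = layer(x)                 # attention -> BN -> MLP -> BN, shape preserved
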
class GraphAttentionEncoder(nn.Module):
r"""
Graph attention by self attention on graph nodes.
For an embedding :math:`\pmb{x}`, repeat ``n_layers`` time:
.. math::
\pmb{h} = \mathrm{MultiHeadAttentionLayer}(\pmb{x})
.. seealso::
Check :class:`~.MultiHeadAttentionLayer` for details.
Args:
n_heads : number of heads
embedding_dim : dimension of the query, keys, values
n_layers : number of :class:`~.MultiHeadAttentionLayer` to iterate.
feed_forward_hidden : size of the hidden layer in the MLP
Inputs: x
* **x**: embeddin :math:`\pmb{x}`. [batch, graph_size, embedding_dim]
Outputs: (h, h_mean)
* **h**: the output :math:`\pmb{h}` [batch, graph_size, embedding_dim]
"""
    def __init__(self, n_heads, embed_dim, n_layers, feed_forward_hidden=512):
        super(GraphAttentionEncoder, self).__init__()
        self.layers = nn.Sequential(
            *(
                MultiHeadAttentionLayer(n_heads, embed_dim, feed_forward_hidden)
                for _ in range(n_layers)
            )
        )

    def forward(self, x, mask=None):
        assert mask is None, "TODO mask not yet supported!"
        h = self.layers(x)
        return (h, h.mean(dim=1))
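
# Usage sketch for GraphAttentionEncoder (illustrative only; sizes are arbitrary
# example values):
#
#     import torch
#     encoder = GraphAttentionEncoder(n_heads=8, embed_dim=128, n_layers=3)
#     x = torch.rand(2, 20, 128)   # [batch, graph_size, embed_dim]
#     h, h_mean = encoder(x)       # h: [2, 20, 128] node embeddings,
#                                  # h_mean: [2, 128] graph embedding (mean over nodes)
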