# encoding: utf-8
"""Class declaration of the Transformer's positional encoding."""

import chainer
import chainer.functions as F
import numpy as np


class PositionalEncoding(chainer.Chain):
"""Positional encoding module.
:param int n_units: embedding dim
:param float dropout: dropout rate
:param int length: maximum input length
"""
    def __init__(self, n_units, dropout=0.1, length=5000):
        """Initialize Positional Encoding."""
        # Implementation described in the paper "Attention Is All You Need"
        super(PositionalEncoding, self).__init__()
        self.dropout = dropout
        # Position indices as a column vector: shape (length, 1).
        posi_block = np.arange(0, length, dtype=np.float32)[:, None]
        # Inverse frequencies 1 / 10000^(2i / n_units) over even dimension indices 2i.
        unit_block = np.exp(
            np.arange(0, n_units, 2, dtype=np.float32) * -(np.log(10000.0) / n_units)
        )
        # Sinusoids on even dimensions, cosines on odd ones (assumes even n_units).
        self.pe = np.zeros((length, n_units), dtype=np.float32)
        self.pe[:, ::2] = np.sin(posi_block * unit_block)
        self.pe[:, 1::2] = np.cos(posi_block * unit_block)
        self.scale = np.sqrt(n_units)

    def forward(self, e):
        """Forward Positional Encoding."""
        length = e.shape[1]
        # Scale the embeddings, then add the positional table, which is
        # broadcast over the batch dimension.
        e = e * self.scale + self.xp.array(self.pe[:length])
        return F.dropout(e, self.dropout)
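

# A minimal usage sketch (an assumption, not part of the original file):
# run a dummy batch through the module to see the shapes involved. All
# dimensions below are illustrative.
if __name__ == "__main__":
    posenc = PositionalEncoding(n_units=8, dropout=0.1, length=16)
    x = np.zeros((2, 4, 8), dtype=np.float32)  # (batch, time, n_units)
    # Disable dropout for a deterministic pass.
    with chainer.using_config("train", False):
        y = posenc.forward(x)
    print(y.shape)  # -> (2, 4, 8)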