"""Added ConMamba and Mamba | |
Authors | |
* Xilin Jiang 2024 | |
""" | |
"""Transformer implementation in the SpeechBrain style. | |
Authors | |
* Jianyuan Zhong 2020 | |
* Samuele Cornell 2021 | |
""" | |

import math
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

import speechbrain as sb
from speechbrain.nnet.activations import Swish
from speechbrain.nnet.attention import RelPosEncXL
from speechbrain.nnet.CNN import Conv1d

# BranchformerEncoder is referenced below but was not imported; this import
# path is an assumption, matching SpeechBrain's own TransformerASR module.
from speechbrain.lobes.models.transformer.Branchformer import (
    BranchformerEncoder,
)

from modules.Conformer import ConformerEncoder
from modules.Conmamba import ConmambaEncoder, MambaDecoder


class TransformerInterface(nn.Module):
    """This is an interface for transformer models.

    Users can modify the attributes and define the forward function as
    needed according to their own tasks.

    The architecture is based on the paper "Attention Is All You Need":
    https://arxiv.org/pdf/1706.03762.pdf

    Arguments
    ---------
    d_model: int
        The number of expected features in the encoder/decoder inputs (default=512).
    nhead: int
        The number of heads in the multi-head attention models (default=8).
    num_encoder_layers: int, optional
        The number of encoder layers in the encoder.
    num_decoder_layers: int, optional
        The number of decoder layers in the decoder.
    d_ffn: int, optional
        The dimension of the feedforward network model hidden layer.
    dropout: float, optional
        The dropout value.
    activation: torch.nn.Module, optional
        The activation function for the Feed-Forward Network layer,
        e.g., relu or gelu or swish.
    custom_src_module: torch.nn.Module, optional
        Module that processes the src features to the expected feature dim.
    custom_tgt_module: torch.nn.Module, optional
        Module that processes the tgt features to the expected feature dim.
    positional_encoding: str, optional
        Type of positional encoding used, e.g. 'fixed_abs_sine' for fixed absolute positional encodings.
    normalize_before: bool, optional
        Whether normalization should be applied before or after MHA or FFN in Transformer layers.
        Defaults to True as this was shown to lead to better performance and training stability.
    kernel_size: int, optional
        Kernel size in convolutional layers when Conformer is used.
    bias: bool, optional
        Whether to use bias in Conformer convolutional layers.
    encoder_module: str, optional
        Choose between Branchformer, Conformer, ConMamba, and Transformer for the encoder.
    decoder_module: str, optional
        Choose between Mamba and Transformer for the decoder.
    conformer_activation: torch.nn.Module, optional
        Activation module used after Conformer convolutional layers, e.g. Swish or ReLU.
        It has to be a torch Module.
    branchformer_activation: torch.nn.Module, optional
        Activation module used within the Branchformer encoder, e.g. Swish or ReLU.
        It has to be a torch Module.
    attention_type: str, optional
        Type of attention layer used in all Transformer or Conformer layers,
        e.g. regularMHA or RelPosMHA.
    max_length: int, optional
        Max length for the target and source sequence in input.
        Used for positional encodings.
    causal: bool, optional
        Whether the encoder should be causal or not (the decoder is always causal).
        If causal, the Conformer convolutional layer is causal.
    encoder_kdim: int, optional
        Dimension of the key for the encoder.
    encoder_vdim: int, optional
        Dimension of the value for the encoder.
    decoder_kdim: int, optional
        Dimension of the key for the decoder.
    decoder_vdim: int, optional
        Dimension of the value for the decoder.
    csgu_linear_units: int, optional
        Number of neurons in the hidden linear units of the CSGU module (Branchformer only).
    gate_activation: torch.nn.Module, optional
        Activation function used at the gate of the CSGU module (Branchformer only).
    use_linear_after_conv: bool, optional
        If True, will apply a linear transformation of size input_size//2 (Branchformer only).
    mamba_config: dict, optional
        Mamba parameters if encoder_module or decoder_module is Mamba or ConMamba.
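
    Example
    -------
    A minimal construction sketch (the values below are illustrative, not a
    recommended configuration; the ConMamba encoder and Mamba decoder
    additionally require a valid ``mamba_config``):

    >>> net = TransformerInterface(
    ...     d_model=512,
    ...     nhead=8,
    ...     num_encoder_layers=2,
    ...     num_decoder_layers=2,
    ... )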
""" | |

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ffn=2048,
        dropout=0.1,
        activation=nn.ReLU,
        custom_src_module=None,
        custom_tgt_module=None,
        positional_encoding="fixed_abs_sine",
        normalize_before=True,
        kernel_size: Optional[int] = 31,
        bias: Optional[bool] = True,
        encoder_module: Optional[str] = "transformer",
        decoder_module: Optional[str] = "transformer",
        conformer_activation: Optional[nn.Module] = Swish,
        branchformer_activation: Optional[nn.Module] = nn.GELU,
        attention_type: Optional[str] = "regularMHA",
        max_length: Optional[int] = 2500,
        causal: Optional[bool] = False,
        encoder_kdim: Optional[int] = None,
        encoder_vdim: Optional[int] = None,
        decoder_kdim: Optional[int] = None,
        decoder_vdim: Optional[int] = None,
        csgu_linear_units: Optional[int] = 3072,
        gate_activation: Optional[nn.Module] = nn.Identity,
        use_linear_after_conv: Optional[bool] = False,
        mamba_config=None,
    ):
        super().__init__()
        self.causal = causal
        self.attention_type = attention_type
        self.positional_encoding_type = positional_encoding
        self.encoder_kdim = encoder_kdim
        self.encoder_vdim = encoder_vdim
        self.decoder_kdim = decoder_kdim
        self.decoder_vdim = decoder_vdim

        assert attention_type in ["regularMHA", "RelPosMHAXL", "hypermixing"]
        assert positional_encoding in ["fixed_abs_sine", None]
        assert (
            num_encoder_layers + num_decoder_layers > 0
        ), "number of encoder layers and number of decoder layers cannot both be 0!"

        if positional_encoding == "fixed_abs_sine":
            self.positional_encoding = PositionalEncoding(d_model, max_length)
        elif positional_encoding is None:
            pass
            # no positional encodings

        # overrides any other pos_embedding
        if attention_type == "RelPosMHAXL":
            self.positional_encoding = RelPosEncXL(d_model)
            self.positional_encoding_decoder = PositionalEncoding(
                d_model, max_length
            )

        # initialize the encoder
        if num_encoder_layers > 0:
            if custom_src_module is not None:
                self.custom_src_module = custom_src_module(d_model)
            if encoder_module == "transformer":
                self.encoder = TransformerEncoder(
                    nhead=nhead,
                    num_layers=num_encoder_layers,
                    d_ffn=d_ffn,
                    d_model=d_model,
                    dropout=dropout,
                    activation=activation,
                    normalize_before=normalize_before,
                    causal=self.causal,
                    attention_type=self.attention_type,
                    kdim=self.encoder_kdim,
                    vdim=self.encoder_vdim,
                )
            elif encoder_module == "conformer":
                self.encoder = ConformerEncoder(
                    nhead=nhead,
                    num_layers=num_encoder_layers,
                    d_ffn=d_ffn,
                    d_model=d_model,
                    dropout=dropout,
                    activation=conformer_activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=self.causal,
                    attention_type=self.attention_type,
                )
                assert (
                    normalize_before
                ), "normalize_before must be True for Conformer"
                assert (
                    conformer_activation is not None
                ), "conformer_activation must not be None"
            elif encoder_module == "branchformer":
                self.encoder = BranchformerEncoder(
                    nhead=nhead,
                    num_layers=num_encoder_layers,
                    d_model=d_model,
                    dropout=dropout,
                    activation=branchformer_activation,
                    kernel_size=kernel_size,
                    attention_type=self.attention_type,
                    csgu_linear_units=csgu_linear_units,
                    gate_activation=gate_activation,
                    use_linear_after_conv=use_linear_after_conv,
                )
            elif encoder_module == "conmamba":
                self.encoder = ConmambaEncoder(
                    num_layers=num_encoder_layers,
                    d_model=d_model,
                    d_ffn=d_ffn,
                    dropout=dropout,
                    activation=branchformer_activation,
                    kernel_size=kernel_size,
                    bias=bias,
                    causal=self.causal,
                    mamba_config=mamba_config,
                )
                assert (
                    normalize_before
                ), "normalize_before must be True for Conmamba"
                assert (
                    conformer_activation is not None
                ), "conformer_activation must not be None"

        # initialize the decoder
        if num_decoder_layers > 0:
            if custom_tgt_module is not None:
                self.custom_tgt_module = custom_tgt_module(d_model)
            if decoder_module == "transformer":
                self.decoder = TransformerDecoder(
                    num_layers=num_decoder_layers,
                    nhead=nhead,
                    d_ffn=d_ffn,
                    d_model=d_model,
                    dropout=dropout,
                    activation=activation,
                    normalize_before=normalize_before,
                    causal=True,
                    attention_type="regularMHA",  # always use regular attention in decoder
                    kdim=self.decoder_kdim,
                    vdim=self.decoder_vdim,
                )
            elif decoder_module == "mamba":
                self.decoder = MambaDecoder(
                    num_layers=num_decoder_layers,
                    d_ffn=d_ffn,
                    d_model=d_model,
                    activation=activation,
                    dropout=dropout,
                    normalize_before=normalize_before,
                    mamba_config=mamba_config,
                )
            else:
                raise NotImplementedError(decoder_module)

    def forward(self, **kwargs):
        """Users should modify this function according to their own tasks."""
        raise NotImplementedError


class PositionalEncoding(nn.Module):
    """This class implements the absolute sinusoidal positional encoding function.

    PE(pos, 2i)   = sin(pos / (10000^(2i/d_model)))
    PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))

    Arguments
    ---------
    input_size: int
        Embedding dimension.
    max_len : int, optional
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = PositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(
                f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})"
            )
        self.max_len = max_len
        pe = torch.zeros(self.max_len, input_size, requires_grad=False)
        positions = torch.arange(0, self.max_len).unsqueeze(1).float()
        denominator = torch.exp(
            torch.arange(0, input_size, 2).float()
            * -(math.log(10000.0) / input_size)
        )
        pe[:, 0::2] = torch.sin(positions * denominator)
        pe[:, 1::2] = torch.cos(positions * denominator)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        Arguments
        ---------
        x : torch.Tensor
            Input feature shape (batch, time, fea).

        Returns
        -------
        The positional encoding.
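
        Example
        -------
        A usage sketch (assumption: the caller adds the returned encoding to
        the input features, relying on broadcasting over the batch dimension):

        >>> enc = PositionalEncoding(input_size=512)
        >>> x = torch.rand((8, 120, 512))
        >>> (x + enc(x)).shape
        torch.Size([8, 120, 512])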
""" | |
return self.pe[:, : x.size(1)].clone().detach() | |


class TransformerEncoderLayer(nn.Module):
    """This is an implementation of a self-attention encoder layer.

    Arguments
    ---------
    d_ffn: int, optional
        The dimension of the feedforward network model hidden layer.
    nhead: int
        The number of heads in the multi-head attention models (default=8).
    d_model: int
        The number of expected features in the encoder/decoder inputs (default=512).
    kdim: int, optional
        Dimension of the key.
    vdim: int, optional
        Dimension of the value.
    dropout: float, optional
        The dropout value.
    activation: torch.nn.Module, optional
        The activation function for the Feed-Forward Network layer,
        e.g., relu or gelu or swish.
    normalize_before: bool, optional
        Whether normalization should be applied before or after MHA or FFN in Transformer layers.
        Defaults to True as this was shown to lead to better performance and training stability.
    attention_type: str, optional
        Type of attention layer used in all Transformer or Conformer layers,
        e.g. regularMHA or RelPosMHA.
    ffn_type: str
        Type of ffn: regularFFN/1dcnn.
    ffn_cnn_kernel_size_list: list of int
        Kernel sizes of the two 1d-convs if ffn_type is 1dcnn.
    causal: bool, optional
        Whether the encoder should be causal or not (the decoder is always causal).
        If causal, the Conformer convolutional layer is causal.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = TransformerEncoderLayer(512, 8, d_model=512)
    >>> output = net(x)
    >>> output[0].shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_ffn,
        nhead,
        d_model,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=nn.ReLU,
        normalize_before=False,
        attention_type="regularMHA",
        ffn_type="regularFFN",
        ffn_cnn_kernel_size_list=[3, 3],
        causal=False,
    ):
        super().__init__()

        if attention_type == "regularMHA":
            self.self_att = sb.nnet.attention.MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                dropout=dropout,
                kdim=kdim,
                vdim=vdim,
            )
        elif attention_type == "RelPosMHAXL":
            self.self_att = sb.nnet.attention.RelPosMHAXL(
                d_model, nhead, dropout, mask_pos_future=causal
            )
        elif attention_type == "hypermixing":
            self.self_att = sb.nnet.hypermixing.HyperMixing(
                input_output_dim=d_model,
                hypernet_size=d_ffn,
                tied=False,
                num_heads=nhead,
                fix_tm_hidden_size=False,
            )

        if ffn_type == "regularFFN":
            self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
                d_ffn=d_ffn,
                input_size=d_model,
                dropout=dropout,
                activation=activation,
            )
        elif ffn_type == "1dcnn":
            self.pos_ffn = nn.Sequential(
                Conv1d(
                    in_channels=d_model,
                    out_channels=d_ffn,
                    kernel_size=ffn_cnn_kernel_size_list[0],
                    padding="causal" if causal else "same",
                ),
                nn.ReLU(),
                Conv1d(
                    in_channels=d_ffn,
                    out_channels=d_model,
                    kernel_size=ffn_cnn_kernel_size_list[1],
                    padding="causal" if causal else "same",
                ),
            )

        self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)

        self.normalize_before = normalize_before
        self.pos_ffn_type = ffn_type

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """
        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder layer.
        src_mask : torch.Tensor
            The mask for the src query for each example in the batch.
        src_key_padding_mask : torch.Tensor, optional
            The mask for the src keys for each example in the batch.
        pos_embs: torch.Tensor, optional
            The positional embeddings tensor.

        Returns
        -------
        output : torch.Tensor
            The output of the transformer encoder layer.
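
        Example
        -------
        A causal-masking sketch (assumption: the float mask produced by
        ``get_lookahead_mask`` at the bottom of this module is a valid
        ``src_mask`` here):

        >>> x = torch.rand((2, 10, 512))
        >>> net = TransformerEncoderLayer(2048, 8, d_model=512)
        >>> out, attn = net(x, src_mask=get_lookahead_mask(x))
        >>> out.shape
        torch.Size([2, 10, 512])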
""" | |
if self.normalize_before: | |
src1 = self.norm1(src) | |
else: | |
src1 = src | |
output, self_attn = self.self_att( | |
src1, | |
src1, | |
src1, | |
attn_mask=src_mask, | |
key_padding_mask=src_key_padding_mask, | |
pos_embs=pos_embs, | |
) | |
# add & norm | |
src = src + self.dropout1(output) | |
if not self.normalize_before: | |
src = self.norm1(src) | |
if self.normalize_before: | |
src1 = self.norm2(src) | |
else: | |
src1 = src | |
output = self.pos_ffn(src1) | |
# add & norm | |
output = src + self.dropout2(output) | |
if not self.normalize_before: | |
output = self.norm2(output) | |
return output, self_attn | |


class TransformerEncoder(nn.Module):
    """This class implements the transformer encoder.

    Arguments
    ---------
    num_layers : int
        Number of transformer layers to include.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Hidden size of the self-attention feed-forward layer.
    input_shape : tuple
        Expected shape of the input.
    d_model : int
        The dimension of the input embedding.
    kdim : int
        Dimension for key (optional).
    vdim : int
        Dimension for value (optional).
    dropout : float
        Dropout for the encoder (optional).
    activation: torch.nn.Module, optional
        The activation function for the Feed-Forward Network layer,
        e.g., relu or gelu or swish.
    normalize_before: bool, optional
        Whether normalization should be applied before or after MHA or FFN in Transformer layers.
        Defaults to True as this was shown to lead to better performance and training stability.
    causal: bool, optional
        Whether the encoder should be causal or not (the decoder is always causal).
        If causal, the Conformer convolutional layer is causal.
    layerdrop_prob: float
        The probability to drop an entire layer.
    attention_type: str, optional
        Type of attention layer used in all Transformer or Conformer layers,
        e.g. regularMHA or RelPosMHA.
    ffn_type: str
        Type of ffn: regularFFN/1dcnn.
    ffn_cnn_kernel_size_list: list of int
        Kernel sizes of the two 1d-convs if ffn_type is 1dcnn.

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = TransformerEncoder(1, 8, 512, d_model=512)
    >>> output, _ = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        nhead,
        d_ffn,
        input_shape=None,
        d_model=None,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=nn.ReLU,
        normalize_before=False,
        causal=False,
        layerdrop_prob=0.0,
        attention_type="regularMHA",
        ffn_type="regularFFN",
        ffn_cnn_kernel_size_list=[3, 3],
    ):
        super().__init__()

        self.layers = torch.nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    normalize_before=normalize_before,
                    causal=causal,
                    attention_type=attention_type,
                    ffn_type=ffn_type,
                    ffn_cnn_kernel_size_list=ffn_cnn_kernel_size_list,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.layerdrop_prob = layerdrop_prob
        self.rng = np.random.default_rng()

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
        dynchunktrain_config=None,
    ):
        """
        Arguments
        ---------
        src : torch.Tensor
            The sequence to the encoder layer (required).
        src_mask : torch.Tensor
            The mask for the src sequence (optional).
        src_key_padding_mask : torch.Tensor
            The mask for the src keys per batch (optional).
        pos_embs : torch.Tensor
            The positional embedding tensor (optional).
        dynchunktrain_config : config
            Not supported for this encoder.

        Returns
        -------
        output : torch.Tensor
            The output of the transformer.
        attention_lst : list
            The attention values.
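
        Example
        -------
        A padding-mask sketch (assumption: the boolean mask returned by
        ``get_mask_from_lengths`` at the bottom of this module is accepted
        as ``src_key_padding_mask``):

        >>> x = torch.rand((2, 4, 512))
        >>> net = TransformerEncoder(1, 8, 2048, d_model=512)
        >>> mask = get_mask_from_lengths(torch.tensor([4, 2]))
        >>> out, _ = net(x, src_key_padding_mask=mask)
        >>> out.shape
        torch.Size([2, 4, 512])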
""" | |
assert ( | |
dynchunktrain_config is None | |
), "Dynamic Chunk Training unsupported for this encoder" | |
output = src | |
if self.layerdrop_prob > 0.0: | |
keep_probs = self.rng.random(len(self.layers)) | |
else: | |
keep_probs = None | |
attention_lst = [] | |
for i, enc_layer in enumerate(self.layers): | |
if ( | |
not self.training | |
or self.layerdrop_prob == 0.0 | |
or keep_probs[i] > self.layerdrop_prob | |
): | |
output, attention = enc_layer( | |
output, | |
src_mask=src_mask, | |
src_key_padding_mask=src_key_padding_mask, | |
pos_embs=pos_embs, | |
) | |
attention_lst.append(attention) | |
output = self.norm(output) | |
return output, attention_lst | |


class TransformerDecoderLayer(nn.Module):
    """This class implements the self-attention decoder layer.

    Arguments
    ---------
    d_ffn : int
        Hidden size of the self-attention feed-forward layer.
    nhead : int
        Number of attention heads.
    d_model : int
        Dimension of the model.
    kdim : int
        Dimension for key (optional).
    vdim : int
        Dimension for value (optional).
    dropout : float
        Dropout for the decoder (optional).
    activation : Callable
        Function to use between layers, default nn.ReLU.
    normalize_before : bool
        Whether to normalize before layers.
    attention_type : str
        Type of attention to use, "regularMHA" or "RelPosMHAXL".
    causal : bool
        Whether to mask future positions.

    Example
    -------
    >>> src = torch.rand((8, 60, 512))
    >>> tgt = torch.rand((8, 60, 512))
    >>> net = TransformerDecoderLayer(1024, 8, d_model=512)
    >>> output, self_attn, multihead_attn = net(tgt, src)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        d_ffn,
        nhead,
        d_model,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=nn.ReLU,
        normalize_before=False,
        attention_type="regularMHA",
        causal=None,
    ):
        super().__init__()
        self.nhead = nhead

        if attention_type == "regularMHA":
            self.self_attn = sb.nnet.attention.MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                kdim=kdim,
                vdim=vdim,
                dropout=dropout,
            )
            self.multihead_attn = sb.nnet.attention.MultiheadAttention(
                nhead=nhead,
                d_model=d_model,
                kdim=kdim,
                vdim=vdim,
                dropout=dropout,
            )
        elif attention_type == "RelPosMHAXL":
            self.self_attn = sb.nnet.attention.RelPosMHAXL(
                d_model, nhead, dropout, mask_pos_future=causal
            )
            self.multihead_attn = sb.nnet.attention.RelPosMHAXL(
                d_model, nhead, dropout, mask_pos_future=causal
            )

        self.pos_ffn = sb.nnet.attention.PositionalwiseFeedForward(
            d_ffn=d_ffn,
            input_size=d_model,
            dropout=dropout,
            activation=activation,
        )

        # normalization layers
        self.norm1 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.norm2 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.norm3 = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dropout2 = torch.nn.Dropout(dropout)
        self.dropout3 = torch.nn.Dropout(dropout)

        self.normalize_before = normalize_before

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt: torch.Tensor
            The sequence to the decoder layer (required).
        memory: torch.Tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask: torch.Tensor
            The mask for the tgt sequence (optional).
        memory_mask: torch.Tensor
            The mask for the memory sequence (optional).
        tgt_key_padding_mask: torch.Tensor
            The mask for the tgt keys per batch (optional).
        memory_key_padding_mask: torch.Tensor
            The mask for the memory keys per batch (optional).
        pos_embs_tgt: torch.Tensor
            The positional embeddings for the target (optional).
        pos_embs_src: torch.Tensor
            The positional embeddings for the source (optional).
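
        Example
        -------
        A causal cross-attention sketch with different target and memory
        lengths (assumption: ``get_lookahead_mask`` at the bottom of this
        module provides the standard triangular target mask):

        >>> memory = torch.rand((8, 60, 512))
        >>> tgt = torch.rand((8, 50, 512))
        >>> net = TransformerDecoderLayer(1024, 8, d_model=512)
        >>> out, self_attn, cross_attn = net(
        ...     tgt, memory, tgt_mask=get_lookahead_mask(tgt)
        ... )
        >>> out.shape
        torch.Size([8, 50, 512])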
""" | |
if self.normalize_before: | |
tgt1 = self.norm1(tgt) | |
else: | |
tgt1 = tgt | |
# self-attention over the target sequence | |
tgt2, self_attn = self.self_attn( | |
query=tgt1, | |
key=tgt1, | |
value=tgt1, | |
attn_mask=tgt_mask, | |
key_padding_mask=tgt_key_padding_mask, | |
pos_embs=pos_embs_tgt, | |
) | |
# add & norm | |
tgt = tgt + self.dropout1(tgt2) | |
if not self.normalize_before: | |
tgt = self.norm1(tgt) | |
if self.normalize_before: | |
tgt1 = self.norm2(tgt) | |
else: | |
tgt1 = tgt | |
# multi-head attention over the target sequence and encoder states | |
tgt2, multihead_attention = self.multihead_attn( | |
query=tgt1, | |
key=memory, | |
value=memory, | |
attn_mask=memory_mask, | |
key_padding_mask=memory_key_padding_mask, | |
pos_embs=pos_embs_src, | |
) | |
# add & norm | |
tgt = tgt + self.dropout2(tgt2) | |
if not self.normalize_before: | |
tgt = self.norm2(tgt) | |
if self.normalize_before: | |
tgt1 = self.norm3(tgt) | |
else: | |
tgt1 = tgt | |
tgt2 = self.pos_ffn(tgt1) | |
# add & norm | |
tgt = tgt + self.dropout3(tgt2) | |
if not self.normalize_before: | |
tgt = self.norm3(tgt) | |
return tgt, self_attn, multihead_attention | |


class TransformerDecoder(nn.Module):
    """This class implements the Transformer decoder.

    Arguments
    ---------
    num_layers : int
        Number of transformer layers for the decoder.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Hidden size of the self-attention feed-forward layer.
    d_model : int
        Dimension of the model.
    kdim : int, optional
        Dimension for key (optional).
    vdim : int, optional
        Dimension for value (optional).
    dropout : float, optional
        Dropout for the decoder (optional).
    activation : Callable
        The function to apply between layers, default nn.ReLU.
    normalize_before : bool
        Whether to normalize before layers.
    causal : bool
        Whether to mask future positions in decoding.
    attention_type : str
        Type of attention to use, "regularMHA" or "RelPosMHAXL".

    Example
    -------
    >>> src = torch.rand((8, 60, 512))
    >>> tgt = torch.rand((8, 60, 512))
    >>> net = TransformerDecoder(1, 8, 1024, d_model=512)
    >>> output, _, _ = net(tgt, src)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_layers,
        nhead,
        d_ffn,
        d_model,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=nn.ReLU,
        normalize_before=False,
        causal=False,
        attention_type="regularMHA",
    ):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_ffn=d_ffn,
                    nhead=nhead,
                    d_model=d_model,
                    kdim=kdim,
                    vdim=vdim,
                    dropout=dropout,
                    activation=activation,
                    normalize_before=normalize_before,
                    causal=causal,
                    attention_type=attention_type,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        pos_embs_tgt=None,
        pos_embs_src=None,
    ):
        """
        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder layer (required).
        memory : torch.Tensor
            The sequence from the last layer of the encoder (required).
        tgt_mask : torch.Tensor
            The mask for the tgt sequence (optional).
        memory_mask : torch.Tensor
            The mask for the memory sequence (optional).
        tgt_key_padding_mask : torch.Tensor
            The mask for the tgt keys per batch (optional).
        memory_key_padding_mask : torch.Tensor
            The mask for the memory keys per batch (optional).
        pos_embs_tgt : torch.Tensor
            The positional embeddings for the target (optional).
        pos_embs_src : torch.Tensor
            The positional embeddings for the source (optional).
        """
        output = tgt
        self_attns, multihead_attns = [], []
        for dec_layer in self.layers:
            output, self_attn, multihead_attn = dec_layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos_embs_tgt=pos_embs_tgt,
                pos_embs_src=pos_embs_src,
            )
            self_attns.append(self_attn)
            multihead_attns.append(multihead_attn)
        output = self.norm(output)

        return output, self_attns, multihead_attns


class NormalizedEmbedding(nn.Module):
    """This class implements the normalized embedding layer for the transformer.

    Since the dot product of the self-attention is always normalized by sqrt(d_model)
    and the final linear projection for prediction shares weight with the embedding layer,
    we multiply the output of the embedding by sqrt(d_model).

    Arguments
    ---------
    d_model: int
        The number of expected features in the encoder/decoder inputs (default=512).
    vocab: int
        The vocab size.

    Example
    -------
    >>> emb = NormalizedEmbedding(512, 1000)
    >>> trg = torch.randint(0, 999, (8, 50))
    >>> emb_fea = emb(trg)
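    >>> emb_fea.shape
    torch.Size([8, 50, 512])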
""" | |
def __init__(self, d_model, vocab): | |
super().__init__() | |
self.emb = sb.nnet.embedding.Embedding( | |
num_embeddings=vocab, embedding_dim=d_model, blank_id=0 | |
) | |
self.d_model = d_model | |
def forward(self, x): | |
"""Processes the input tensor x and returns an output tensor.""" | |
return self.emb(x) * math.sqrt(self.d_model) | |


def get_key_padding_mask(padded_input, pad_idx):
    """Creates a binary mask to prevent attention to padded locations.

    We suggest using ``get_mask_from_lengths`` instead of this function.

    Arguments
    ---------
    padded_input: torch.Tensor
        Padded input.
    pad_idx: int
        idx for padding element.

    Returns
    -------
    key_padded_mask: torch.Tensor
        Binary mask to prevent attention to padding.

    Example
    -------
    >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
    >>> get_key_padding_mask(a, pad_idx=0)
    tensor([[False, False,  True],
            [False, False,  True],
            [False, False,  True]])
    """
    if len(padded_input.shape) == 4:
        bz, time, ch1, ch2 = padded_input.shape
        padded_input = padded_input.reshape(bz, time, ch1 * ch2)

    key_padded_mask = padded_input.eq(pad_idx).to(padded_input.device)

    # if the input is more than 2d, mask the locations where they are silence
    # across all channels
    if len(padded_input.shape) > 2:
        key_padded_mask = key_padded_mask.float().prod(dim=-1).bool()
        return key_padded_mask.detach()

    return key_padded_mask.detach()


def get_lookahead_mask(padded_input):
    """Creates a binary mask for each sequence which masks future frames.

    Arguments
    ---------
    padded_input: torch.Tensor
        Padded input tensor.

    Returns
    -------
    mask : torch.Tensor
        Binary mask for masking future frames.

    Example
    -------
    >>> a = torch.LongTensor([[1,1,0], [2,3,0], [4,5,0]])
    >>> get_lookahead_mask(a)
    tensor([[0., -inf, -inf],
            [0., 0., -inf],
            [0., 0., 0.]])
    """
    seq_len = padded_input.shape[1]
    mask = (
        torch.triu(torch.ones((seq_len, seq_len), device=padded_input.device))
        == 1
    ).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask.detach().to(padded_input.device)


def get_mask_from_lengths(lengths, max_len=None):
    """Creates a binary mask from sequence lengths.

    Arguments
    ---------
    lengths: torch.Tensor
        A tensor of sequence lengths.
    max_len: int (Optional)
        Maximum sequence length, defaults to None.

    Returns
    -------
    mask: torch.Tensor
        The mask where padded elements are set to True.
        Then one can use ``tensor.masked_fill_(mask, 0)`` for the masking.

    Example
    -------
    >>> lengths = torch.tensor([3, 2, 4])
    >>> get_mask_from_lengths(lengths)
    tensor([[False, False, False,  True],
            [False, False,  True,  True],
            [False, False, False, False]])
    """
    if max_len is None:
        max_len = torch.max(lengths).item()
    seq_range = torch.arange(
        max_len, device=lengths.device, dtype=lengths.dtype
    )
    return ~(seq_range.unsqueeze(0) < lengths.unsqueeze(1))