File size: 10,478 Bytes
b6c45cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
from math import ceil
import warnings
import torch.nn as nn
from torch.nn.modules.activation import MultiheadAttention
from ..masknn import activations, norms
import torch
from ..dsp.overlap_add import DualPathProcessing
import inspect
class ImprovedTransformedLayer(nn.Module):
    """Improved Transformer layer as used in [1].

    Multi-head self-attention (residual connection + normalization) followed
    by an LSTM, an activation and a linear projection back to ``embed_dim``
    (residual connection + normalization). The LSTM replaces the
    position-wise feed-forward layer of a plain Transformer.

    Args:
        embed_dim (int): Number of input channels. Must be divisible by
            ``n_heads`` (required by ``MultiheadAttention``).
        n_heads (int): Number of attention heads.
        dim_ff (int): Hidden size of the LSTM (per direction).
        dropout (float, optional): Dropout ratio, must be in [0,1]. Passed to
            the attention module and applied after the attention, after the
            activation and after the output projection.
        activation (str, optional): Name of the activation applied to the
            LSTM output (looked up with ``activations.get``).
        bidirectional (bool, optional): Whether the LSTM is bidirectional.
            When True the projection input size is ``2 * dim_ff``.
        norm (str, optional): Name of the normalization (looked up with
            ``norms.get``) applied after each residual connection.

    References:
        [1] Chen, Jingjing, Qirong Mao, and Dong Liu.
        "Dual-Path Transformer Network: Direct Context-Aware Modeling for
        End-to-End Monaural Speech Separation."
        arXiv preprint arXiv:2007.13975 (2020).
    """

    def __init__(
        self,
        embed_dim,
        n_heads,
        dim_ff,
        dropout=0.0,
        activation="relu",
        bidirectional=True,
        norm="gLN",
    ):
        super().__init__()
        self.mha = MultiheadAttention(embed_dim, n_heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.recurrent = nn.LSTM(embed_dim, dim_ff, bidirectional=bidirectional, batch_first=True)
        # A bidirectional LSTM concatenates both directions on the feature axis.
        ff_inner_dim = 2 * dim_ff if bidirectional else dim_ff
        self.linear = nn.Linear(ff_inner_dim, embed_dim)
        self.activation = activations.get(activation)()
        self.norm_mha = norms.get(norm)(embed_dim)
        self.norm_ff = norms.get(norm)(embed_dim)

    def forward(self, x):
        """Apply self-attention then the recurrent feed-forward sub-layer.

        Args:
            x (torch.Tensor): Tensor of shape (batch, channels, seq_len) with
                ``channels == embed_dim``.

        Returns:
            torch.Tensor: Tensor of shape (batch, channels, seq_len)
            (assuming the configured norm preserves its input shape).
        """
        # MultiheadAttention (not batch_first) expects (seq_len, batch, channels).
        tomha = x.permute(2, 0, 1)
        attended = self.mha(tomha, tomha, tomha)[0]
        # Back to (batch, channels, seq_len) for the residual and the norm.
        x = self.dropout(attended.permute(1, 2, 0)) + x
        x = self.norm_mha(x)

        # The LSTM is batch_first, so it expects (batch, seq_len, channels).
        recurrent_out = self.recurrent(x.transpose(1, -1))[0]
        projected = self.linear(self.dropout(self.activation(recurrent_out)))
        x = self.dropout(projected.transpose(1, -1)) + x
        return self.norm_ff(x)
class DPTransformer(nn.Module):
    """Dual-path Transformer introduced in [1].

    The input is (if needed) linearly projected to a channel count divisible
    by ``n_heads`` and normalized. It is then split into overlapping chunks
    and processed by ``n_repeats`` pairs of :class:`ImprovedTransformedLayer`
    (one intra-chunk, one inter-chunk). Finally it is folded back and turned
    into ``n_src`` gated masks.

    Args:
        in_chan (int): Number of input filters.
        n_src (int): Number of masks to estimate.
        n_heads (int): Number of attention heads. Defaults to 4.
        ff_hid (int): Hidden size of the LSTM inside each layer.
            Defaults to 256.
        chunk_size (int): Window size of overlap and add processing.
            Defaults to 100.
        hop_size (int or None): Hop size (stride) of overlap and add processing.
            Defaults to ``chunk_size // 2`` (50% overlap).
        n_repeats (int): Number of repeats. Defaults to 6.
        norm_type (str, optional): Type of normalization to use.
        ff_activation (str, optional): Activation function applied at the
            output of the LSTM.
        mask_act (str, optional): Which non-linear function to generate mask.
        bidirectional (bool, optional): True for bidirectional Inter-Chunk LSTM
            (Intra-Chunk is always bidirectional).
        dropout (float, optional): Dropout ratio, must be in [0,1].

    References:
        [1] Chen, Jingjing, Qirong Mao, and Dong Liu. "Dual-Path Transformer
        Network: Direct Context-Aware Modeling for End-to-End Monaural Speech
        Separation." arXiv (2020).
    """

    def __init__(
        self,
        in_chan,
        n_src,
        n_heads=4,
        ff_hid=256,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        ff_activation="relu",
        mask_act="relu",
        bidirectional=True,
        dropout=0,
    ):
        super().__init__()
        self.in_chan = in_chan
        self.n_src = n_src
        self.n_heads = n_heads
        self.ff_hid = ff_hid
        self.chunk_size = chunk_size
        self.hop_size = hop_size if hop_size is not None else chunk_size // 2
        self.n_repeats = n_repeats
        self.norm_type = norm_type
        self.ff_activation = ff_activation
        self.mask_act = mask_act
        self.bidirectional = bidirectional
        self.dropout = dropout

        # MultiheadAttention needs an embedding size divisible by n_heads:
        # round in_chan up to the next multiple and add a projection if needed.
        self.mha_in_dim = ceil(self.in_chan / self.n_heads) * self.n_heads
        if self.in_chan % self.n_heads != 0:
            warnings.warn(
                f"DPTransformer input dim ({self.in_chan}) is not a multiple of the number of "
                f"heads ({self.n_heads}). Adding extra linear layer at input to accomodate "
                f"(size [{self.in_chan} x {self.mha_in_dim}])"
            )
            self.input_layer = nn.Linear(self.in_chan, self.mha_in_dim)
        else:
            self.input_layer = None
        self.in_norm = norms.get(norm_type)(self.mha_in_dim)
        self.ola = DualPathProcessing(self.chunk_size, self.hop_size)

        # n_repeats pairs of (intra-chunk, inter-chunk) transformer layers.
        # The intra-chunk layer is always bidirectional.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        ImprovedTransformedLayer(
                            self.mha_in_dim,
                            self.n_heads,
                            self.ff_hid,
                            dropout=self.dropout,
                            activation=self.ff_activation,
                            bidirectional=True,
                            norm=self.norm_type,
                        ),
                        ImprovedTransformedLayer(
                            self.mha_in_dim,
                            self.n_heads,
                            self.ff_hid,
                            dropout=self.dropout,
                            activation=self.ff_activation,
                            bidirectional=self.bidirectional,
                            norm=self.norm_type,
                        ),
                    ]
                )
                for _ in range(self.n_repeats)
            ]
        )
        net_out_conv = nn.Conv2d(self.mha_in_dim, n_src * self.in_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        self.net_out = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Tanh())
        self.net_gate = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Sigmoid())
        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()

    def forward(self, mixture_w):
        r"""Forward.

        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$

        Returns:
            :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
        """
        if self.input_layer is not None:
            # Linear acts on the last axis, so move channels there and back.
            mixture_w = self.input_layer(mixture_w.transpose(1, 2)).transpose(1, 2)
        mixture_w = self.in_norm(mixture_w)  # [batch, mha_in_dim, n_frames]
        n_orig_frames = mixture_w.shape[-1]

        mixture_w = self.ola.unfold(mixture_w)
        # Unpack into locals: forward must not mutate module attributes
        # (the original code overwrote self.chunk_size here).
        batch, _, chunk_size, n_chunks = mixture_w.size()

        for intra, inter in self.layers:
            mixture_w = self.ola.intra_process(mixture_w, intra)
            mixture_w = self.ola.inter_process(mixture_w, inter)

        output = self.first_out(mixture_w)
        output = output.reshape(batch * self.n_src, self.in_chan, chunk_size, n_chunks)
        output = self.ola.fold(output, output_size=n_orig_frames)
        output = self.net_out(output) * self.net_gate(output)

        # Compute mask; the source axis is dim 1 (used by softmax-like activations).
        output = output.reshape(batch, self.n_src, self.in_chan, -1)
        return self.output_act(output)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this module."""
        return {
            "in_chan": self.in_chan,
            "ff_hid": self.ff_hid,
            "n_heads": self.n_heads,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "ff_activation": self.ff_activation,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "dropout": self.dropout,
        }
def has_arg(fn, name):
    """Tell whether ``fn`` explicitly declares a parameter ``name`` usable as a keyword.

    Only parameters listed by name in the signature that may be passed by
    keyword (positional-or-keyword or keyword-only) count. A ``**kwargs``
    catch-all, ``*args`` or a positional-only parameter does not.

    Args:
        fn (callable): Callable whose signature is inspected.
        name (str): Parameter name to look for.

    Returns:
        bool: ``True`` if ``fn(name=...)`` is explicitly supported, else ``False``.
    """
    keyword_kinds = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    param = inspect.signature(fn).parameters.get(name)
    return param is not None and param.kind in keyword_kinds
|