# File size: 3,839 Bytes
# b6c45cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from ..filterbanks import make_enc_dec
from ..masknn import DPTransformer
from .base_models import BaseEncoderMaskerDecoder


class DPTNet(BaseEncoderMaskerDecoder):
    """DPTNet separation model, as described in [1].

    An encoder/masker/decoder model whose masker is a :class:`DPTransformer`.
    The encoder and decoder come from :func:`make_enc_dec`.

    Args:
        n_src (int): Number of masks (sources) to estimate.
        ff_hid (int): Hidden size of the feed-forward part of the masker's
            transformer blocks (forwarded to ``DPTransformer`` as ``ff_hid``).
            Defaults to 256.
        chunk_size (int): Window size of overlap and add processing.
            Defaults to 100.
        hop_size (int or None): Hop size (stride) of overlap and add
            processing. If None (default), ``chunk_size // 2`` is used
            (50% overlap).
        n_repeats (int): Number of repeats of the dual-path block.
            Defaults to 6.
        norm_type (str, optional): Type of normalization to use. To choose
            from

            - ``'gLN'``: global Layernorm (default)
            - ``'cLN'``: channelwise Layernorm

        ff_activation (str, optional): Activation used in the feed-forward
            part of the masker (forwarded to ``DPTransformer``).
            Defaults to ``"relu"``.
        encoder_activation (str, optional): Activation forwarded to
            :class:`BaseEncoderMaskerDecoder` for the encoder output.
            Defaults to ``"relu"``.
        mask_act (str, optional): Which non-linear function to generate the
            mask. Defaults to ``"relu"``.
        bidirectional (bool, optional): True for a bidirectional inter-chunk
            path (the intra-chunk path is always bidirectional).
            Defaults to True.
        dropout (float, optional): Dropout ratio, must be in [0, 1].
            Defaults to 0.
        in_chan (int, optional): Number of input channels of the masker.
            Only used as a sanity check: when given, it must equal the
            number of filterbank output features (``encoder.n_feats_out``).
            Defaults to None (no check).
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``]. Defaults to ``"free"``.
        kernel_size (int): Length of the filters. Defaults to 16.
        n_filters (int): Number of filters / input dimension of the masker
            net. Defaults to 64.
        stride (int, optional): Stride of the convolution. Defaults to 8.
        **fb_kwargs (dict): Additional kwargs to pass to the filterbank
            creation.

    Raises:
        AssertionError: if ``in_chan`` is given and differs from the number
            of filterbank output features.

    References:
        [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct
            Context-Aware Modeling for End-to-End Monaural Speech Separation"
            Interspeech 2020.
    """

    def __init__(
        self,
        n_src,
        ff_hid=256,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        ff_activation="relu",
        encoder_activation="relu",
        mask_act="relu",
        bidirectional=True,
        dropout=0,
        in_chan=None,
        fb_name="free",
        kernel_size=16,
        n_filters=64,
        stride=8,
        **fb_kwargs,
    ):
        encoder, decoder = make_enc_dec(
            fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs
        )
        # The masker's input size is dictated by the filterbank; `in_chan`
        # is only checked against it, never used directly.
        n_feats = encoder.n_feats_out
        if in_chan is not None:
            assert in_chan == n_feats, (
                "Number of filterbank output channels"
                " and number of input channels should "
                "be the same. Received "
                f"{n_feats} and {in_chan}"
            )
        masker = DPTransformer(
            n_feats,
            n_src,
            ff_hid=ff_hid,
            ff_activation=ff_activation,
            chunk_size=chunk_size,
            hop_size=hop_size,
            n_repeats=n_repeats,
            norm_type=norm_type,
            mask_act=mask_act,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)