DeepLearning101 committed on
Commit b6c45cb · verified · 1 Parent(s): 406f587

Upload 16 files

DPTNet_eval/DPTNet_quant_sep.py ADDED
@@ -0,0 +1,108 @@
+ # DPTNet_quant_sep.py
+
+ import os
+ import torch
+ import numpy as np
+ import torchaudio
+ from huggingface_hub import hf_hub_download
+ from . import asteroid_test
+
+ torchaudio.set_audio_backend("sox_io")
+
+ def get_conf():
+     conf_filterbank = {
+         'n_filters': 64,
+         'kernel_size': 16,
+         'stride': 8
+     }
+
+     conf_masknet = {
+         'in_chan': 64,
+         'n_src': 2,
+         'out_chan': 64,
+         'ff_hid': 256,
+         'ff_activation': "relu",
+         'norm_type': "gLN",
+         'chunk_size': 100,
+         'hop_size': 50,
+         'n_repeats': 2,
+         'mask_act': 'sigmoid',
+         'bidirectional': True,
+         'dropout': 0
+     }
+     return conf_filterbank, conf_masknet
+
+
+ def load_dpt_model():
+     print('Load Separation Model...')
+
+     # Read the Hugging Face token from the environment variable
+     HF_TOKEN = os.getenv("HF_TOKEN")
+     if not HF_TOKEN:
+         raise EnvironmentError("Environment variable HF_TOKEN is not set! Run export HF_TOKEN=xxx first.")
+
+     # Download the model weights from the Hugging Face Hub
+     model_path = hf_hub_download(
+         repo_id="DeepLearning101/speech-separation",  # ← replace with your own repo name
+         filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
+         token=HF_TOKEN
+     )
+
+     # Get the model hyperparameters
+     conf_filterbank, conf_masknet = get_conf()
+
+     # Build the model architecture
+     model_class = getattr(asteroid_test, "DPTNet")
+     model = model_class(**conf_filterbank, **conf_masknet)
+
+     # Apply dynamic quantization
+     model = torch.quantization.quantize_dynamic(
+         model,
+         {torch.nn.LSTM, torch.nn.Linear},
+         dtype=torch.qint8
+     )
+
+     # Load the weights (ignoring mismatched keys)
+     state_dict = torch.load(model_path, map_location="cpu")
+     model_state_dict = model.state_dict()
+     filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
+     model.load_state_dict(filtered_state_dict, strict=False)
+     model.eval()
+
+     return model
+
+
+ def dpt_sep_process(wav_path, model=None, outfilename=None):
+     if model is None:
+         model = load_dpt_model()
+
+     x, sr = torchaudio.load(wav_path)
+     x = x.cpu()
+
+     with torch.no_grad():
+         est_sources = model(x)  # shape: (1, 2, T)
+
+     est_sources = est_sources.squeeze(0)  # shape: (2, T)
+     sep_1, sep_2 = est_sources  # split into two (T,) tensors
+
+     # Rescale each source to the mixture's peak amplitude
+     max_abs = x[0].abs().max().item()
+     sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
+     sep_2 = sep_2 * max_abs / sep_2.abs().max().item()
+
+     # Add a channel dimension, giving shape (1, T)
+     sep_1 = sep_1.unsqueeze(0)
+     sep_2 = sep_2.unsqueeze(0)
+
+     # Save the results
+     if outfilename is not None:
+         torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
+         torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
+         torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
+     else:
+         torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
+         torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)
+
+
+ if __name__ == '__main__':
+     print("This module should be used via Flask or Gradio.")
DPTNet_eval/asteroid_test/__init__.py ADDED
@@ -0,0 +1,19 @@
+ import pathlib
+
+ from .models import DPTNet
+ from .utils import torch_utils  # noqa
+
+ project_root = str(pathlib.Path(__file__).expanduser().absolute().parent.parent)
+ __version__ = "0.3.4"
+
+
+ def show_available_models():
+     from .utils.hub_utils import MODELS_URLS_HASHTABLE
+
+     print(" \n".join(list(MODELS_URLS_HASHTABLE.keys())))
+
+
+ __all__ = [
+     "DPTNet",
+     "show_available_models",
+ ]
DPTNet_eval/asteroid_test/dsp/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .overlap_add import DualPathProcessing
+
+ __all__ = [
+     "DualPathProcessing",
+ ]
DPTNet_eval/asteroid_test/dsp/overlap_add.py ADDED
@@ -0,0 +1,317 @@
+ import torch
+ from scipy.signal import get_window
+ # from asteroid_test.losses import PITLossWrapper
+ from torch import nn
+
+ '''
+ class LambdaOverlapAdd(torch.nn.Module):
+     """Overlap-add with lambda transform on segments.
+
+     Segment input signal, apply lambda function (a neural network for example)
+     and combine with OLA.
+
+     Args:
+         nnet (callable): Function to apply to each segment.
+         n_src (int): Number of sources in the output of nnet.
+         window_size (int): Size of segmenting window.
+         hop_size (int): Segmentation hop size.
+         window (str): Name of the window (see scipy.signal.get_window) used
+             for the synthesis.
+         reorder_chunks (bool): Whether to reorder each consecutive segment.
+             This might be useful when `nnet` is permutation invariant, as
+             source assignments might change output channel from one segment
+             to the next (in classic speech separation for example).
+             Reordering is performed based on the correlation between
+             the overlapped parts of consecutive segments.
+
+     Examples:
+         >>> from asteroid_test import ConvTasNet
+         >>> nnet = ConvTasNet(n_src=2)
+         >>> continuous_nnet = LambdaOverlapAdd(
+         >>>     nnet=nnet,
+         >>>     n_src=2,
+         >>>     window_size=64000,
+         >>>     hop_size=None,
+         >>>     window="hanning",
+         >>>     reorder_chunks=True,
+         >>>     enable_grad=False,
+         >>> )
+         >>> wav = torch.randn(1, 1, 500000)
+         >>> out_wavs = continuous_nnet.forward(wav)
+     """
+
+     def __init__(
+         self,
+         nnet,
+         n_src,
+         window_size,
+         hop_size=None,
+         window="hanning",
+         reorder_chunks=True,
+         enable_grad=False,
+     ):
+         super().__init__()
+         assert window_size % 2 == 0, "Window size must be even"
+
+         self.nnet = nnet
+         self.window_size = window_size
+         self.hop_size = hop_size if hop_size is not None else window_size // 2
+         self.n_src = n_src
+
+         if window:
+             window = get_window(window, self.window_size).astype("float32")
+             window = torch.from_numpy(window)
+             self.use_window = True
+         else:
+             self.use_window = False
+
+         self.register_buffer("window", window)
+         self.reorder_chunks = reorder_chunks
+         self.enable_grad = enable_grad
+
+     def ola_forward(self, x):
+         """Heart of the class: segment signal, apply func, combine with OLA."""
+
+         assert x.ndim == 3
+
+         batch, channels, n_frames = x.size()
+         # Overlap and add:
+         # [batch, chans, n_frames] -> [batch, chans, win_size, n_chunks]
+         unfolded = torch.nn.functional.unfold(
+             x.unsqueeze(-1),
+             kernel_size=(self.window_size, 1),
+             padding=(self.window_size, 0),
+             stride=(self.hop_size, 1),
+         )
+
+         out = []
+         n_chunks = unfolded.shape[-1]
+         for frame_idx in range(n_chunks):  # for loop to spare memory
+             frame = self.nnet(unfolded[..., frame_idx])
+             # user must handle multichannel by reshaping to batch
+             if frame_idx == 0:
+                 assert frame.ndim == 3, "nnet should return (batch, n_src, time)"
+                 assert frame.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
+             frame = frame.reshape(batch * self.n_src, -1)
+
+             if frame_idx != 0 and self.reorder_chunks:
+                 # we determine best perm based on xcorr with previous sources
+                 frame = _reorder_sources(
+                     frame, out[-1], self.n_src, self.window_size, self.hop_size
+                 )
+
+             if self.use_window:
+                 frame = frame * self.window
+             else:
+                 frame = frame / (self.window_size / self.hop_size)
+             out.append(frame)
+
+         out = torch.stack(out).reshape(n_chunks, batch * self.n_src, self.window_size)
+         out = out.permute(1, 2, 0)
+
+         out = torch.nn.functional.fold(
+             out,
+             (n_frames, 1),
+             kernel_size=(self.window_size, 1),
+             padding=(self.window_size, 0),
+             stride=(self.hop_size, 1),
+         )
+         return out.squeeze(-1).reshape(batch, self.n_src, -1)
+
+     def forward(self, x):
+         """Forward module: segment signal, apply func, combine with OLA.
+
+         Args:
+             x (:class:`torch.Tensor`): waveform signal of shape (batch, 1, time).
+
+         Returns:
+             :class:`torch.Tensor`: The output of the lambda OLA.
+         """
+         # Here we can do the reshaping
+         with torch.autograd.set_grad_enabled(self.enable_grad):
+             olad = self.ola_forward(x)
+             return olad
+
+
+ def _reorder_sources(
+     current: torch.FloatTensor,
+     previous: torch.FloatTensor,
+     n_src: int,
+     window_size: int,
+     hop_size: int,
+ ):
+     """
+     Reorder sources in current chunk to maximize correlation with previous chunk.
+     Used for Continuous Source Separation. Standard dsp correlation is used
+     for reordering.
+
+     Args:
+         current (:class:`torch.Tensor`): current chunk, tensor
+             of shape (batch, n_src, window_size)
+         previous (:class:`torch.Tensor`): previous chunk, tensor
+             of shape (batch, n_src, window_size)
+         n_src (:class:`int`): number of sources.
+         window_size (:class:`int`): window_size, equal to last dimension of
+             both current and previous.
+         hop_size (:class:`int`): hop_size between current and previous tensors.
+
+     Returns:
+         current (:class:`torch.Tensor`): reordered current chunk.
+     """
+     batch, frames = current.size()
+     current = current.reshape(-1, n_src, frames)
+     previous = previous.reshape(-1, n_src, frames)
+
+     overlap_f = window_size - hop_size
+
+     def reorder_func(x, y):
+         x = x[..., :overlap_f]
+         y = y[..., -overlap_f:]
+         # Mean normalization
+         x = x - x.mean(-1, keepdim=True)
+         y = y - y.mean(-1, keepdim=True)
+         # Negative mean Correlation
+         return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)
+
+     # We maximize correlation-like between previous and current.
+     pit = PITLossWrapper(reorder_func)
+     current = pit(current, previous, return_est=True)[1]
+     return current.reshape(batch, frames)
+ '''
+
+
+ class DualPathProcessing(nn.Module):
+     """Perform Dual-Path processing via overlap-add as in DPRNN [1].
+
+     Args:
+         chunk_size (int): Size of segmenting window.
+         hop_size (int): segmentation hop size.
+
+     References:
+         [1] "Dual-path RNN: efficient long sequence modeling for
+         time-domain single-channel speech separation", Yi Luo, Zhuo Chen
+         and Takuya Yoshioka. https://arxiv.org/abs/1910.06379
+     """
+
+     def __init__(self, chunk_size, hop_size):
+         super(DualPathProcessing, self).__init__()
+         self.chunk_size = chunk_size
+         self.hop_size = hop_size
+         self.n_orig_frames = None
+
+     def unfold(self, x):
+         """Unfold the feature tensor from
+         (batch, channels, time) to (batch, channels, chunk_size, n_chunks).
+
+         Args:
+             x (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).
+
+         Returns:
+             :class:`torch.Tensor`: spliced feature tensor of shape
+             (batch, channels, chunk_size, n_chunks).
+         """
+         # x is (batch, chan, frames)
+         batch, chan, frames = x.size()
+         assert x.ndim == 3
+         self.n_orig_frames = x.shape[-1]
+         unfolded = torch.nn.functional.unfold(
+             x.unsqueeze(-1),
+             kernel_size=(self.chunk_size, 1),
+             padding=(self.chunk_size, 0),
+             stride=(self.hop_size, 1),
+         )
+
+         return unfolded.reshape(
+             batch, chan, self.chunk_size, -1
+         )  # (batch, chan, chunk_size, n_chunks)
+
+     def fold(self, x, output_size=None):
+         """Folds back the spliced feature tensor.
+
+         Input shape (batch, channels, chunk_size, n_chunks) to original shape
+         (batch, channels, time) using overlap-add.
+
+         Args:
+             x (:class:`torch.Tensor`): spliced feature tensor of shape
+                 (batch, channels, chunk_size, n_chunks).
+             output_size (int, optional): sequence length of original feature tensor.
+                 If None, the original length cached by the previous call of `unfold`
+                 will be used.
+
+         Returns:
+             :class:`torch.Tensor`: feature tensor of shape (batch, channels, time).
+
+         .. note:: `fold` caches the original length of the tensor that was passed
+             to the previous call of `unfold`.
+         """
+         output_size = output_size if output_size is not None else self.n_orig_frames
+         # x is (batch, chan, chunk_size, n_chunks)
+         batch, chan, chunk_size, n_chunks = x.size()
+         to_unfold = x.reshape(batch, chan * self.chunk_size, n_chunks)
+         x = torch.nn.functional.fold(
+             to_unfold,
+             (output_size, 1),
+             kernel_size=(self.chunk_size, 1),
+             padding=(self.chunk_size, 0),
+             stride=(self.hop_size, 1),
+         )
+
+         # Normalize for the overlap-add gain.
+         x /= self.chunk_size / self.hop_size
+
+         return x.reshape(batch, chan, self.n_orig_frames)
+
+     @staticmethod
+     def intra_process(x, module):
+         """Performs intra-chunk processing.
+
+         Args:
+             x (:class:`torch.Tensor`): spliced feature tensor of shape
+                 (batch, channels, chunk_size, n_chunks).
+             module (:class:`torch.nn.Module`): module one wishes to apply to each chunk
+                 of the spliced feature tensor.
+
+         Returns:
+             x (:class:`torch.Tensor`): processed spliced feature tensor of shape
+             (batch, channels, chunk_size, n_chunks).
+
+         .. note:: the module should have the channel first convention and accept
+             a 3D tensor of shape (batch, channels, time).
+         """
+
+         # x is (batch, channels, chunk_size, n_chunks)
+         batch, channels, chunk_size, n_chunks = x.size()
+         # we reshape to batch*chunk_size, channels, n_chunks
+         x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
+         x = module(x)
+         x = x.reshape(batch, n_chunks, channels, chunk_size).transpose(1, -1).transpose(1, 2)
+         return x
+
+     @staticmethod
+     def inter_process(x, module):
+         """Performs inter-chunk processing.
+
+         Args:
+             x (:class:`torch.Tensor`): spliced feature tensor of shape
+                 (batch, channels, chunk_size, n_chunks).
+             module (:class:`torch.nn.Module`): module one wishes to apply between
+                 each chunk of the spliced feature tensor.
+
+         Returns:
+             x (:class:`torch.Tensor`): processed spliced feature tensor of shape
+             (batch, channels, chunk_size, n_chunks).
+
+         .. note:: the module should have the channel first convention and accept
+             a 3D tensor of shape (batch, channels, time).
+         """
+
+         batch, channels, chunk_size, n_chunks = x.size()
+         x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
+         x = module(x)
+         x = x.reshape(batch, chunk_size, channels, n_chunks).transpose(1, 2)
+         return x
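
A shape-level sketch of the DualPathProcessing round trip (illustrative, not part of the commit); identity modules stand in for the intra/inter networks:

    import torch
    from DPTNet_eval.asteroid_test.dsp.overlap_add import DualPathProcessing

    dpp = DualPathProcessing(chunk_size=100, hop_size=50)
    feats = torch.randn(2, 64, 1000)               # (batch, channels, time)
    chunks = dpp.unfold(feats)                     # (2, 64, 100, n_chunks)
    chunks = dpp.intra_process(chunks, torch.nn.Identity())
    chunks = dpp.inter_process(chunks, torch.nn.Identity())
    out = dpp.fold(chunks)                         # overlap-add back to (2, 64, 1000)
    assert out.shape == feats.shape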
DPTNet_eval/asteroid_test/filterbanks/__init__.py ADDED
@@ -0,0 +1,107 @@
+ # from .analytic_free_fb import AnalyticFreeFB
+ from .free_fb import FreeFB
+ from .enc_dec import Filterbank, Encoder, Decoder
+
+
+ def make_enc_dec(
+     fb_name,
+     n_filters,
+     kernel_size,
+     stride=None,
+     who_is_pinv=None,
+     padding=0,
+     output_padding=0,
+     **kwargs,
+ ):
+     """Creates congruent encoder and decoder from the same filterbank family.
+
+     Args:
+         fb_name (str, className): Filterbank family from which to make encoder
+             and decoder. To choose among [``'free'``, ``'analytic_free'``,
+             ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
+             submodule in this subpackage (e.g. :class:`~.FreeFB`).
+         n_filters (int): Number of filters.
+         kernel_size (int): Length of the filters.
+         stride (int, optional): Stride of the convolution.
+             If None (default), set to ``kernel_size // 2``.
+         who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
+             be used. If string (among [``'encoder'``, ``'decoder'``]), decides
+             which of ``Encoder`` or ``Decoder`` will be the pseudo inverse of
+             the other one.
+         padding (int): Zero-padding added to both sides of the input.
+             Passed to Encoder and Decoder.
+         output_padding (int): Additional size added to one side of the output shape.
+             Passed to Decoder.
+         **kwargs: Arguments which will be passed to the filterbank class
+             additionally to the usual `n_filters`, `kernel_size` and `stride`.
+             Depends on the filterbank family.
+     Returns:
+         :class:`.Encoder`, :class:`.Decoder`
+     """
+     fb_class = get(fb_name)
+
+     if who_is_pinv in ["dec", "decoder"]:
+         fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
+         enc = Encoder(fb, padding=padding)
+         # Decoder filterbank is pseudo inverse of encoder filterbank.
+         dec = Decoder.pinv_of(fb)
+     elif who_is_pinv in ["enc", "encoder"]:
+         fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
+         dec = Decoder(fb, padding=padding, output_padding=output_padding)
+         # Encoder filterbank is pseudo inverse of decoder filterbank.
+         enc = Encoder.pinv_of(fb)
+     else:
+         fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
+         enc = Encoder(fb, padding=padding)
+         # Filters between encoder and decoder should not be shared.
+         fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
+         dec = Decoder(fb, padding=padding, output_padding=output_padding)
+     return enc, dec
+
+
+ def register_filterbank(custom_fb):
+     """Register a custom filterbank, gettable with `filterbanks.get`.
+
+     Args:
+         custom_fb: Custom filterbank to register.
+     """
+     if custom_fb.__name__ in globals().keys() or custom_fb.__name__.lower() in globals().keys():
+         raise ValueError(f"Filterbank {custom_fb.__name__} already exists. Choose another name.")
+     globals().update({custom_fb.__name__: custom_fb})
+
+
+ def get(identifier):
+     """Returns a filterbank class from a string. Returns its input if it
+     is callable (already a :class:`.Filterbank` for example).
+
+     Args:
+         identifier (str or Callable or None): the filterbank identifier.
+
+     Returns:
+         :class:`.Filterbank` or None
+     """
+     if identifier is None:
+         return None
+     elif callable(identifier):
+         return identifier
+     elif isinstance(identifier, str):
+         cls = globals().get(identifier)
+         if cls is None:
+             raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
+         return cls
+     else:
+         raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
+
+
+ # Aliases.
+ free = FreeFB
+
+ # For the docs
+ __all__ = [
+     "Filterbank",
+     "Encoder",
+     "Decoder",
+     "FreeFB",
+     "make_enc_dec",
+ ]
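
A quick sketch of make_enc_dec with the ``'free'`` family (illustrative; the filter/stride values mirror get_conf in DPTNet_quant_sep.py):

    import torch
    from DPTNet_eval.asteroid_test.filterbanks import make_enc_dec

    enc, dec = make_enc_dec("free", n_filters=64, kernel_size=16, stride=8)
    wav = torch.randn(4, 1, 8000)                  # (batch, 1, time)
    tf_rep = enc(wav)                              # (4, 64, n_frames)
    rec = dec(tf_rep)                              # (4, 1, time)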
DPTNet_eval/asteroid_test/filterbanks/enc_dec.py ADDED
@@ -0,0 +1,267 @@
+ import warnings
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class Filterbank(nn.Module):
+     """Base Filterbank class.
+     Each subclass has to implement a `filters` property.
+
+     Args:
+         n_filters (int): Number of filters.
+         kernel_size (int): Length of the filters.
+         stride (int, optional): Stride of the conv or transposed conv. (Hop size).
+             If None (default), set to ``kernel_size // 2``.
+
+     Attributes:
+         n_feats_out (int): Number of output filters.
+     """
+
+     def __init__(self, n_filters, kernel_size, stride=None):
+         super(Filterbank, self).__init__()
+         self.n_filters = n_filters
+         self.kernel_size = kernel_size
+         self.stride = stride if stride else self.kernel_size // 2
+         # If not specified otherwise in the filterbank's init, output
+         # number of features is equal to number of required filters.
+         self.n_feats_out = n_filters
+
+     @property
+     def filters(self):
+         """ Abstract method for filters. """
+         raise NotImplementedError
+
+     def get_config(self):
+         """ Returns dictionary of arguments to re-instantiate the class. """
+         config = {
+             "fb_name": self.__class__.__name__,
+             "n_filters": self.n_filters,
+             "kernel_size": self.kernel_size,
+             "stride": self.stride,
+         }
+         return config
+
+
+ class _EncDec(nn.Module):
+     """Base private class for Encoder and Decoder.
+
+     Common parameters and methods.
+
+     Args:
+         filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
+             to use as an encoder or a decoder.
+         is_pinv (bool): Whether to be the pseudo inverse of filterbank.
+
+     Attributes:
+         filterbank (:class:`Filterbank`)
+         stride (int)
+         is_pinv (bool)
+     """
+
+     def __init__(self, filterbank, is_pinv=False):
+         super(_EncDec, self).__init__()
+         self.filterbank = filterbank
+         self.stride = self.filterbank.stride
+         self.is_pinv = is_pinv
+
+     @property
+     def filters(self):
+         return self.filterbank.filters
+
+     def compute_filter_pinv(self, filters):
+         """ Computes pseudo inverse filterbank of given filters."""
+         scale = self.filterbank.stride / self.filterbank.kernel_size
+         shape = filters.shape
+         ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
+         # Compensate for the overlap-add.
+         return ifilt * scale
+
+     def get_filters(self):
+         """ Returns filters or pinv filters depending on `is_pinv` attribute """
+         if self.is_pinv:
+             return self.compute_filter_pinv(self.filters)
+         else:
+             return self.filters
+
+     def get_config(self):
+         """ Returns dictionary of arguments to re-instantiate the class."""
+         config = {"is_pinv": self.is_pinv}
+         base_config = self.filterbank.get_config()
+         return dict(list(base_config.items()) + list(config.items()))
+
+
+ class Encoder(_EncDec):
+     """Encoder class.
+
+     Add encoding methods to Filterbank classes.
+     Not intended to be subclassed.
+
+     Args:
+         filterbank (:class:`Filterbank`): The filterbank to use
+             as an encoder.
+         is_pinv (bool): Whether to be the pseudo inverse of filterbank.
+         as_conv1d (bool): Whether to behave like nn.Conv1d.
+             If True (default), forwarding input with shape (batch, 1, time)
+             will output a tensor of shape (batch, freq, conv_time).
+             If False, will output a tensor of shape (batch, 1, freq, conv_time).
+         padding (int): Zero-padding added to both sides of the input.
+     """
+
+     def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
+         super(Encoder, self).__init__(filterbank, is_pinv=is_pinv)
+         self.as_conv1d = as_conv1d
+         self.n_feats_out = self.filterbank.n_feats_out
+         self.padding = padding
+
+     @classmethod
+     def pinv_of(cls, filterbank, **kwargs):
+         """Returns an :class:`~.Encoder`, pseudo inverse of a
+         :class:`~.Filterbank` or :class:`~.Decoder`."""
+         if isinstance(filterbank, Filterbank):
+             return cls(filterbank, is_pinv=True, **kwargs)
+         elif isinstance(filterbank, Decoder):
+             return cls(filterbank.filterbank, is_pinv=True, **kwargs)
+
+     def forward(self, waveform):
+         """Convolve input waveform with the filters from a filterbank.
+         Args:
+             waveform (:class:`torch.Tensor`): any tensor with samples along the
+                 last dimension, with any leading batch/channel dimensions.
+         Returns:
+             :class:`torch.Tensor`: The corresponding TF domain signal.
+
+         Shapes:
+             >>> (time, ) --> (freq, conv_time)
+             >>> (batch, time) --> (batch, freq, conv_time)  # Avoid
+             >>> if as_conv1d:
+             >>>     (batch, 1, time) --> (batch, freq, conv_time)
+             >>>     (batch, chan, time) --> (batch, chan, freq, conv_time)
+             >>> else:
+             >>>     (batch, chan, time) --> (batch, chan, freq, conv_time)
+             >>> (batch, any, dim, time) --> (batch, any, dim, freq, conv_time)
+         """
+         filters = self.get_filters()
+         if waveform.ndim == 1:
+             # Assumes 1D input with shape (time,)
+             # Output will be (freq, conv_time)
+             return F.conv1d(
+                 waveform[None, None], filters, stride=self.stride, padding=self.padding
+             ).squeeze()
+         elif waveform.ndim == 2:
+             # Assume 2D input with shape (batch or channels, time)
+             # Output will be (batch or channels, freq, conv_time)
+             warnings.warn(
+                 "Input tensor was 2D. Applying the corresponding "
+                 "Decoder to the current output will result in a 3D "
+                 "tensor. This behaviour was introduced to match "
+                 "Conv1D and ConvTranspose1D, please use 3D inputs "
+                 "to avoid it. For example, this can be done with "
+                 "input_tensor.unsqueeze(1)."
+             )
+             return F.conv1d(
+                 waveform.unsqueeze(1), filters, stride=self.stride, padding=self.padding
+             )
+         elif waveform.ndim == 3:
+             batch, channels, time_len = waveform.shape
+             if channels == 1 and self.as_conv1d:
+                 # That's the common single channel case (batch, 1, time)
+                 # Output will be (batch, freq, stft_time), behaves as Conv1D
+                 return F.conv1d(waveform, filters, stride=self.stride, padding=self.padding)
+             else:
+                 # Return batched convolution, input is (batch, chan, time),
+                 # output will be (batch, chan, freq, conv_time).
+                 # Useful for multichannel transforms.
+                 # If as_conv1d is false, (batch, 1, time) will output
+                 # (batch, 1, freq, conv_time), useful for consistency.
+                 return self.batch_1d_conv(waveform, filters)
+         else:  # waveform.ndim > 3
+             # This is to compute "multi"multichannel convolution.
+             # Input can be (*, time), output will be (*, freq, conv_time)
+             return self.batch_1d_conv(waveform, filters)
+
+     def batch_1d_conv(self, inp, filters):
+         # Here we perform multichannel / multi-source convolution.
+         # Output should be (batch, channels, freq, conv_time)
+         batched_conv = F.conv1d(
+             inp.view(-1, 1, inp.shape[-1]), filters, stride=self.stride, padding=self.padding
+         )
+         output_shape = inp.shape[:-1] + batched_conv.shape[-2:]
+         return batched_conv.view(output_shape)
+
+
+ class Decoder(_EncDec):
+     """Decoder class.
+
+     Add decoding methods to Filterbank classes.
+     Not intended to be subclassed.
+
+     Args:
+         filterbank (:class:`Filterbank`): The filterbank to use as a decoder.
+         is_pinv (bool): Whether to be the pseudo inverse of filterbank.
+         padding (int): Zero-padding added to both sides of the input.
+         output_padding (int): Additional size added to one side of the
+             output shape.
+
+     Notes:
+         `padding` and `output_padding` arguments are directly passed to
+         F.conv_transpose1d.
+     """
+
+     def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
+         super().__init__(filterbank, is_pinv=is_pinv)
+         self.padding = padding
+         self.output_padding = output_padding
+
+     @classmethod
+     def pinv_of(cls, filterbank):
+         """ Returns a Decoder, pseudo inverse of a filterbank or Encoder."""
+         if isinstance(filterbank, Filterbank):
+             return cls(filterbank, is_pinv=True)
+         elif isinstance(filterbank, Encoder):
+             return cls(filterbank.filterbank, is_pinv=True)
+
+     def forward(self, spec):
+         """Applies transposed convolution to a TF representation.
+
+         This is equivalent to overlap-add.
+
+         Args:
+             spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
+                 representation. (Output of :func:`Encoder.forward`).
+         Returns:
+             :class:`torch.Tensor`: The corresponding time domain signal.
+         """
+         filters = self.get_filters()
+         if spec.ndim == 2:
+             # Input is (freq, conv_time), output is (time,)
+             return F.conv_transpose1d(
+                 spec.unsqueeze(0),
+                 filters,
+                 stride=self.stride,
+                 padding=self.padding,
+                 output_padding=self.output_padding,
+             ).squeeze()
+         if spec.ndim == 3:
+             # Input is (batch, freq, conv_time), output is (batch, 1, time)
+             return F.conv_transpose1d(
+                 spec,
+                 filters,
+                 stride=self.stride,
+                 padding=self.padding,
+                 output_padding=self.output_padding,
+             )
+         elif spec.ndim > 3:
+             # Multiply all the left dimensions together and group them in the
+             # batch. Make the convolution and restore.
+             view_as = (-1,) + spec.shape[-2:]
+             out = F.conv_transpose1d(
+                 spec.view(view_as),
+                 filters,
+                 stride=self.stride,
+                 padding=self.padding,
+                 output_padding=self.output_padding,
+             )
+             return out.view(spec.shape[:-2] + (-1,))
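
A small sketch of the pseudo-inverse pairing described above (illustrative, not part of the commit):

    import torch
    from DPTNet_eval.asteroid_test.filterbanks import FreeFB, Encoder, Decoder

    fb = FreeFB(n_filters=64, kernel_size=16, stride=8)
    enc = Encoder(fb)
    dec = Decoder.pinv_of(fb)                      # decoder filters = pinv of encoder's
    wav = torch.randn(1, 1, 4000)
    rec = dec(enc(wav))                            # approximate reconstruction of wav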
DPTNet_eval/asteroid_test/filterbanks/free_fb.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ import torch.nn as nn
+ from .enc_dec import Filterbank
+
+
+ class FreeFB(Filterbank):
+     """Free filterbank without any constraints. Equivalent to
+     :class:`nn.Conv1d`.
+
+     Args:
+         n_filters (int): Number of filters.
+         kernel_size (int): Length of the filters.
+         stride (int, optional): Stride of the convolution.
+             If None (default), set to ``kernel_size // 2``.
+
+     Attributes:
+         n_feats_out (int): Number of output filters.
+
+     References:
+         [1] : "Filterbank design for end-to-end speech separation".
+         Submitted to ICASSP 2020. Manuel Pariente, Samuele Cornell,
+         Antoine Deleforge, Emmanuel Vincent.
+     """
+
+     def __init__(self, n_filters, kernel_size, stride=None, **kwargs):
+         super(FreeFB, self).__init__(n_filters, kernel_size, stride=stride)
+         self._filters = nn.Parameter(torch.ones(n_filters, 1, kernel_size))
+         for p in self.parameters():
+             nn.init.xavier_normal_(p)
+
+     @property
+     def filters(self):
+         return self._filters
DPTNet_eval/asteroid_test/masknn/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # from .convolutional import TDConvNet, TDConvNetpp, SuDORMRF, SuDORMRFImproved
+ # from .recurrent import DPRNN, LSTMMasker
+ from .attention import DPTransformer
+
+ __all__ = [
+     # "TDConvNet",
+     # "DPRNN",
+     "DPTransformer",
+     # "LSTMMasker",
+     # "SuDORMRF",
+     # "SuDORMRFImproved",
+ ]
DPTNet_eval/asteroid_test/masknn/activations.py ADDED
@@ -0,0 +1,82 @@
+ from functools import partial
+ import torch
+ from torch import nn
+
+
+ class Swish(nn.Module):
+     def __init__(self):
+         super(Swish, self).__init__()
+
+     def forward(self, x):
+         return x * torch.sigmoid(x)
+
+
+ def linear():
+     return nn.Identity()
+
+
+ def relu():
+     return nn.ReLU()
+
+
+ def prelu():
+     return nn.PReLU()
+
+
+ def leaky_relu():
+     return nn.LeakyReLU()
+
+
+ def sigmoid():
+     return nn.Sigmoid()
+
+
+ def softmax(dim=None):
+     return nn.Softmax(dim=dim)
+
+
+ def tanh():
+     return nn.Tanh()
+
+
+ def gelu():
+     return nn.GELU()
+
+
+ def swish():
+     return Swish()
+
+
+ def register_activation(custom_act):
+     """Register a custom activation, gettable with `activation.get`.
+
+     Args:
+         custom_act: Custom activation function to register.
+     """
+     if custom_act.__name__ in globals().keys() or custom_act.__name__.lower() in globals().keys():
+         raise ValueError(f"Activation {custom_act.__name__} already exists. Choose another name.")
+     globals().update({custom_act.__name__: custom_act})
+
+
+ def get(identifier):
+     """Returns an activation function from a string. Returns its input if it
+     is callable (already an activation for example).
+
+     Args:
+         identifier (str or Callable or None): the activation identifier.
+
+     Returns:
+         :class:`nn.Module` or None
+     """
+     if identifier is None:
+         return None
+     elif callable(identifier):
+         return identifier
+     elif isinstance(identifier, str):
+         cls = globals().get(identifier)
+         if cls is None:
+             raise ValueError("Could not interpret activation identifier: " + str(identifier))
+         return cls
+     else:
+         raise ValueError("Could not interpret activation identifier: " + str(identifier))
DPTNet_eval/asteroid_test/masknn/attention.py ADDED
@@ -0,0 +1,271 @@
+ from math import ceil
+ import warnings
+
+ import torch.nn as nn
+ from torch.nn.modules.activation import MultiheadAttention
+ from ..masknn import activations, norms
+ import torch
+ from ..dsp.overlap_add import DualPathProcessing
+
+ import inspect
+
+
+ class ImprovedTransformedLayer(nn.Module):
+     """
+     Improved Transformer module as used in [1].
+     It is Multi-Head self-attention followed by LSTM, activation and linear projection layer.
+
+     Args:
+         embed_dim (int): Number of input channels.
+         n_heads (int): Number of attention heads.
+         dim_ff (int): Number of neurons in the RNN's cell state.
+             Defaults to 256. The RNN here replaces the standard FF linear layer of the plain Transformer.
+         dropout (float, optional): Dropout ratio, must be in [0,1].
+         activation (str, optional): activation function applied at the output of RNN.
+         bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
+             (Intra-Chunk is always bidirectional).
+         norm (str, optional): Type of normalization to use.
+
+     References:
+         [1] Chen, Jingjing, Qirong Mao, and Dong Liu.
+         "Dual-Path Transformer Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation."
+         arXiv preprint arXiv:2007.13975 (2020).
+     """
+
+     def __init__(
+         self,
+         embed_dim,
+         n_heads,
+         dim_ff,
+         dropout=0.0,
+         activation="relu",
+         bidirectional=True,
+         norm="gLN",
+     ):
+         super(ImprovedTransformedLayer, self).__init__()
+
+         self.mha = MultiheadAttention(embed_dim, n_heads, dropout=dropout)
+         # self.linear_first = nn.Linear(embed_dim, 2 * dim_ff)  # Added by Kay. 20201119
+         self.dropout = nn.Dropout(dropout)
+         self.recurrent = nn.LSTM(embed_dim, dim_ff, bidirectional=bidirectional, batch_first=True)
+         ff_inner_dim = 2 * dim_ff if bidirectional else dim_ff
+         self.linear = nn.Linear(ff_inner_dim, embed_dim)
+         self.activation = activations.get(activation)()
+         self.norm_mha = norms.get(norm)(embed_dim)
+         self.norm_ff = norms.get(norm)(embed_dim)
+
+     def forward(self, x):
+         tomha = x.permute(2, 0, 1)
+         # x is (batch, channels, seq_len)
+         # mha input is (seq_len, batch, channels)
+         # self-attention is applied
+         out = self.mha(tomha, tomha, tomha)[0]
+         x = self.dropout(out.permute(1, 2, 0)) + x
+         x = self.norm_mha(x)
+
+         # lstm is applied
+         out = self.linear(self.dropout(self.activation(self.recurrent(x.transpose(1, -1))[0])))
+         x = self.dropout(out.transpose(1, -1)) + x
+         return self.norm_ff(x)
+
+     ''' version 0.3.4
+     def forward(self, x):
+         x = x.transpose(1, -1)
+         # x is batch, seq_len, channels
+         # self-attention is applied
+         out = self.mha(x, x, x)[0]
+         x = self.dropout(out) + x
+         x = self.norm_mha(x.transpose(1, -1)).transpose(1, -1)
+
+         # lstm is applied
+         out = self.linear(self.dropout(self.activation(self.recurrent(x)[0])))
+         # out = self.linear(self.dropout(self.activation(self.linear_first(x)[0])))
+         x = self.dropout(out) + x
+         return self.norm_ff(x.transpose(1, -1))
+     '''
+
+
+ class DPTransformer(nn.Module):
+     """Dual-path Transformer introduced in [1].
+
+     Args:
+         in_chan (int): Number of input filters.
+         n_src (int): Number of masks to estimate.
+         n_heads (int): Number of attention heads.
+         ff_hid (int): Number of neurons in the RNN's cell state.
+             Defaults to 256.
+         chunk_size (int): window size of overlap and add processing.
+             Defaults to 100.
+         hop_size (int or None): hop size (stride) of overlap and add processing.
+             Defaults to `chunk_size // 2` (50% overlap).
+         n_repeats (int): Number of repeats. Defaults to 6.
+         norm_type (str, optional): Type of normalization to use.
+         ff_activation (str, optional): activation function applied at the output of RNN.
+         mask_act (str, optional): Which non-linear function to generate mask.
+         bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
+             (Intra-Chunk is always bidirectional).
+         dropout (float, optional): Dropout ratio, must be in [0,1].
+
+     References:
+         [1] Chen, Jingjing, Qirong Mao, and Dong Liu. "Dual-Path Transformer
+         Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation."
+         arXiv (2020).
+     """
+
+     def __init__(
+         self,
+         in_chan,
+         n_src,
+         n_heads=4,
+         ff_hid=256,
+         chunk_size=100,
+         hop_size=None,
+         n_repeats=6,
+         norm_type="gLN",
+         ff_activation="relu",
+         mask_act="relu",
+         bidirectional=True,
+         dropout=0,
+     ):
+         super(DPTransformer, self).__init__()
+         self.in_chan = in_chan
+         self.n_src = n_src
+         self.n_heads = n_heads
+         self.ff_hid = ff_hid
+         self.chunk_size = chunk_size
+         hop_size = hop_size if hop_size is not None else chunk_size // 2
+         self.hop_size = hop_size
+         self.n_repeats = n_repeats
+         self.n_src = n_src
+         self.norm_type = norm_type
+         self.ff_activation = ff_activation
+         self.mask_act = mask_act
+         self.bidirectional = bidirectional
+         self.dropout = dropout
+
+         # version 0.3.4
+         # self.in_norm = norms.get(norm_type)(in_chan)
+         self.mha_in_dim = ceil(self.in_chan / self.n_heads) * self.n_heads
+         if self.in_chan % self.n_heads != 0:
+             warnings.warn(
+                 f"DPTransformer input dim ({self.in_chan}) is not a multiple of the number of "
+                 f"heads ({self.n_heads}). Adding extra linear layer at input to accommodate "
+                 f"(size [{self.in_chan} x {self.mha_in_dim}])"
+             )
+             self.input_layer = nn.Linear(self.in_chan, self.mha_in_dim)
+         else:
+             self.input_layer = None
+
+         self.in_norm = norms.get(norm_type)(self.mha_in_dim)
+         self.ola = DualPathProcessing(self.chunk_size, self.hop_size)
+
+         # Succession of DPRNNBlocks.
+         self.layers = nn.ModuleList([])
+         for x in range(self.n_repeats):
+             self.layers.append(
+                 nn.ModuleList(
+                     [
+                         ImprovedTransformedLayer(
+                             self.mha_in_dim,
+                             self.n_heads,
+                             self.ff_hid,
+                             self.dropout,
+                             self.ff_activation,
+                             True,
+                             self.norm_type,
+                         ),
+                         ImprovedTransformedLayer(
+                             self.mha_in_dim,
+                             self.n_heads,
+                             self.ff_hid,
+                             self.dropout,
+                             self.ff_activation,
+                             self.bidirectional,
+                             self.norm_type,
+                         ),
+                     ]
+                 )
+             )
+         net_out_conv = nn.Conv2d(self.mha_in_dim, n_src * self.in_chan, 1)
+         self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
+         # Gating and masking in 2D space (after fold)
+         self.net_out = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Tanh())
+         self.net_gate = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Sigmoid())
+
+         # Get activation function.
+         mask_nl_class = activations.get(mask_act)
+         # For softmax, feed the source dimension.
+         if has_arg(mask_nl_class, "dim"):
+             self.output_act = mask_nl_class(dim=1)
+         else:
+             self.output_act = mask_nl_class()
+
+     def forward(self, mixture_w):
+         r"""Forward.
+
+         Args:
+             mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$
+
+         Returns:
+             :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
+         """
+         if self.input_layer is not None:
+             mixture_w = self.input_layer(mixture_w.transpose(1, 2)).transpose(1, 2)
+         mixture_w = self.in_norm(mixture_w)  # [batch, bn_chan, n_frames]
+         n_orig_frames = mixture_w.shape[-1]
+
+         mixture_w = self.ola.unfold(mixture_w)
+         batch, n_filters, self.chunk_size, n_chunks = mixture_w.size()
+
+         for layer_idx in range(len(self.layers)):
+             intra, inter = self.layers[layer_idx]
+             mixture_w = self.ola.intra_process(mixture_w, intra)
+             mixture_w = self.ola.inter_process(mixture_w, inter)
+
+         output = self.first_out(mixture_w)
+         output = output.reshape(batch * self.n_src, self.in_chan, self.chunk_size, n_chunks)
+         output = self.ola.fold(output, output_size=n_orig_frames)
+
+         output = self.net_out(output) * self.net_gate(output)
+         # Compute mask
+         output = output.reshape(batch, self.n_src, self.in_chan, -1)
+         est_mask = self.output_act(output)
+         return est_mask
+
+     def get_config(self):
+         config = {
+             "in_chan": self.in_chan,
+             "ff_hid": self.ff_hid,
+             "n_heads": self.n_heads,
+             "chunk_size": self.chunk_size,
+             "hop_size": self.hop_size,
+             "n_repeats": self.n_repeats,
+             "n_src": self.n_src,
+             "norm_type": self.norm_type,
+             "ff_activation": self.ff_activation,
+             "mask_act": self.mask_act,
+             "bidirectional": self.bidirectional,
+             "dropout": self.dropout,
+         }
+         return config
+
+
+ def has_arg(fn, name):
+     """Checks if a callable accepts a given keyword argument.
+
+     Args:
+         fn (callable): Callable to inspect.
+         name (str): Check if `fn` can be called with `name` as a keyword
+             argument.
+
+     Returns:
+         bool: whether `fn` accepts a `name` keyword argument.
+     """
+     signature = inspect.signature(fn)
+     parameter = signature.parameters.get(name)
+     if parameter is None:
+         return False
+     return parameter.kind in (
+         inspect.Parameter.POSITIONAL_OR_KEYWORD,
+         inspect.Parameter.KEYWORD_ONLY,
+     )
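
A shape-level sketch of the masker on its own (illustrative, not part of the commit; n_repeats=2 matches conf_masknet in DPTNet_quant_sep.py):

    import torch
    from DPTNet_eval.asteroid_test.masknn import DPTransformer

    masker = DPTransformer(in_chan=64, n_src=2, n_repeats=2)
    mixture_w = torch.randn(1, 64, 500)            # (batch, n_filters, n_frames)
    est_mask = masker(mixture_w)                   # (1, 2, 64, 500)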
DPTNet_eval/asteroid_test/masknn/norms.py ADDED
@@ -0,0 +1,156 @@
+ from functools import partial
+ import torch
+ from torch import nn
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ EPS = 1e-8
+
+
+ class _LayerNorm(nn.Module):
+     """Layer Normalization base class."""
+
+     def __init__(self, channel_size):
+         super(_LayerNorm, self).__init__()
+         self.channel_size = channel_size
+         self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
+         self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)
+
+     def apply_gain_and_bias(self, normed_x):
+         """ Assumes input of size `[batch, channel, *]`. """
+         return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)
+
+
+ class GlobLN(_LayerNorm):
+     """Global Layer Normalization (globLN)."""
+
+     def forward(self, x):
+         """Applies forward pass.
+
+         Works for any input size > 2D.
+
+         Args:
+             x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
+
+         Returns:
+             :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
+         """
+         dims = list(range(1, len(x.shape)))
+         mean = x.mean(dim=dims, keepdim=True)
+         var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
+         return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
+
+
+ class ChanLN(_LayerNorm):
+     """Channel-wise Layer Normalization (chanLN)."""
+
+     def forward(self, x):
+         """Applies forward pass.
+
+         Works for any input size > 2D.
+
+         Args:
+             x (:class:`torch.Tensor`): `[batch, chan, *]`
+
+         Returns:
+             :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
+         """
+         mean = torch.mean(x, dim=1, keepdim=True)
+         var = torch.var(x, dim=1, keepdim=True, unbiased=False)
+         return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
+
+
+ class CumLN(_LayerNorm):
+     """Cumulative Global layer normalization (cumLN)."""
+
+     def forward(self, x):
+         """
+         Args:
+             x (:class:`torch.Tensor`): Shape `[batch, channels, length]`
+         Returns:
+             :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`
+         """
+         batch, chan, spec_len = x.size()
+         cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
+         cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
+         cnt = torch.arange(start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype).view(
+             1, 1, -1
+         )
+         cum_mean = cum_sum / cnt
+         # Cumulative variance: E[x^2] - E[x]^2
+         cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
+         return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())
+
+
+ class FeatsGlobLN(_LayerNorm):
+     """Feature-wise global Layer Normalization (FeatsGlobLN).
+     Applies normalization over frames for each channel."""
+
+     def forward(self, x):
+         """Applies forward pass.
+
+         Works for any input size > 2D.
+
+         Args:
+             x (:class:`torch.Tensor`): `[batch, chan, time]`
+
+         Returns:
+             :class:`torch.Tensor`: fgLN_x `[batch, chan, time]`
+         """
+
+         stop = len(x.size())
+         dims = list(range(2, stop))
+
+         mean = torch.mean(x, dim=dims, keepdim=True)
+         var = torch.var(x, dim=dims, keepdim=True, unbiased=False)
+         return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
+
+
+ class BatchNorm(_BatchNorm):
+     """Wrapper class for pytorch BatchNorm1D and BatchNorm2D"""
+
+     def _check_input_dim(self, input):
+         if input.dim() < 2 or input.dim() > 4:
+             raise ValueError("expected 2D, 3D or 4D input (got {}D input)".format(input.dim()))
+
+
+ # Aliases.
+ gLN = GlobLN
+ fgLN = FeatsGlobLN
+ cLN = ChanLN
+ cgLN = CumLN
+ bN = BatchNorm
+
+
+ def register_norm(custom_norm):
+     """Register a custom norm, gettable with `norms.get`.
+
+     Args:
+         custom_norm: Custom norm to register.
+     """
+     if custom_norm.__name__ in globals().keys() or custom_norm.__name__.lower() in globals().keys():
+         raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.")
+     globals().update({custom_norm.__name__: custom_norm})
+
+
+ def get(identifier):
+     """Returns a norm class from a string. Returns its input if it
+     is callable (already a :class:`._LayerNorm` for example).
+
+     Args:
+         identifier (str or Callable or None): the norm identifier.
+
+     Returns:
+         :class:`._LayerNorm` or None
+     """
+     if identifier is None:
+         return None
+     elif callable(identifier):
+         return identifier
+     elif isinstance(identifier, str):
+         cls = globals().get(identifier)
+         if cls is None:
+             raise ValueError("Could not interpret normalization identifier: " + str(identifier))
+         return cls
+     else:
+         raise ValueError("Could not interpret normalization identifier: " + str(identifier))
DPTNet_eval/asteroid_test/models/__init__.py ADDED
@@ -0,0 +1,59 @@
+ # Models
+ # from .conv_tasnet import ConvTasNet
+ # from .dccrnet import DCCRNet
+ # from .dcunet import DCUNet
+ # from .dprnn_tasnet import DPRNNTasNet
+ # from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet
+ from .dptnet import DPTNet
+ # from .lstm_tasnet import LSTMTasNet
+ # from .demask import DeMask
+
+ # Sharing-related
+ # from .publisher import save_publishable, upload_publishable
+
+ __all__ = [
+     # "ConvTasNet",
+     # "DPRNNTasNet",
+     # "SuDORMRFImprovedNet",
+     # "SuDORMRFNet",
+     "DPTNet",
+     # "LSTMTasNet",
+     # "DeMask",
+     # "DCUNet",
+     # "DCCRNet",
+     # "save_publishable",
+     # "upload_publishable",
+ ]
+
+
+ def register_model(custom_model):
+     """Register a custom model, gettable with `models.get`.
+
+     Args:
+         custom_model: Custom model to register.
+     """
+     if (
+         custom_model.__name__ in globals().keys()
+         or custom_model.__name__.lower() in globals().keys()
+     ):
+         raise ValueError(f"Model {custom_model.__name__} already exists. Choose another name.")
+     globals().update({custom_model.__name__: custom_model})
+
+
+ def get(identifier):
+     """Returns a model class from a string (case-insensitive).
+
+     Args:
+         identifier (str): the model name.
+
+     Returns:
+         :class:`torch.nn.Module`
+     """
+     if isinstance(identifier, str):
+         to_get = {k.lower(): v for k, v in globals().items()}
+         cls = to_get.get(identifier.lower())
+         if cls is None:
+             raise ValueError(f"Could not interpret model name: {str(identifier)}")
+         return cls
+     raise ValueError(f"Could not interpret model name: {str(identifier)}")
DPTNet_eval/asteroid_test/models/base_models.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+
8
+ from ..masknn import activations
9
+ from ..utils.torch_utils import pad_x_to_y
10
+
11
+
12
+ def _unsqueeze_to_3d(x):
13
+ if x.ndim == 1:
14
+ return x.reshape(1, 1, -1)
15
+ elif x.ndim == 2:
16
+ return x.unsqueeze(1)
17
+ else:
18
+ return x
19
+
20
+
21
+ class BaseModel(nn.Module):
22
+ def __init__(self):
23
+ print("initialize BaseModel")
24
+ super().__init__()
25
+
26
+ def forward(self, *args, **kwargs):
27
+ raise NotImplementedError
28
+
29
+ @torch.no_grad()
30
+ def separate(self, wav, output_dir=None, force_overwrite=False, **kwargs):
31
+ """Infer separated sources from input waveforms.
32
+ Also supports filenames.
33
+
34
+ Args:
35
+ wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
36
+ Shape: 1D, 2D or 3D tensor, time last.
37
+ output_dir (str): path to save all the wav files. If None,
38
+ estimated sources will be saved next to the original ones.
39
+ force_overwrite (bool): whether to overwrite existing files.
40
+ **kwargs: keyword arguments to be passed to `_separate`.
41
+
42
+ Returns:
43
+ Union[torch.Tensor, numpy.ndarray, None], the estimated sources.
44
+ (batch, n_src, time) or (n_src, time) w/o batch dim.
45
+
46
+ .. note::
47
+ By default, `separate` calls `_separate` which calls `forward`.
48
+ For models whose `forward` doesn't return waveform tensors,
49
+ overwrite `_separate` to return waveform tensors.
50
+ """
51
+ if isinstance(wav, str):
52
+ self.file_separate(
53
+ wav, output_dir=output_dir, force_overwrite=force_overwrite, **kwargs
54
+ )
55
+ elif isinstance(wav, np.ndarray):
56
+ print("is ndarray")
57
+ # import pdb ; pdb.set_trace()
58
+ return self.numpy_separate(wav, **kwargs)
59
+ elif isinstance(wav, torch.Tensor):
60
+ print("is torch.Tensor")
61
+ return self.torch_separate(wav, **kwargs)
62
+ else:
63
+ raise ValueError(
64
+ f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
65
+ )
66
+
67
+ def torch_separate(self, wav: torch.Tensor, **kwargs) -> torch.Tensor:
68
+ """ Core logic of `separate`."""
69
+ # Handle device placement
70
+ input_device = wav.device
71
+ model_device = next(self.parameters()).device
72
+ wav = wav.to(model_device)
73
+ # Forward
74
+ out_wavs = self._separate(wav, **kwargs)
75
+
76
+ # FIXME: for now this is the best we can do.
77
+ out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())
78
+
79
+ # Back to input device (and numpy if necessary)
80
+ out_wavs = out_wavs.to(input_device)
81
+ return out_wavs
82
+
83
+ def numpy_separate(self, wav: np.ndarray, **kwargs) -> np.ndarray:
84
+ """ Numpy interface to `separate`."""
85
+ wav = torch.from_numpy(wav)
86
+ out_wav = self.torch_separate(wav, **kwargs)
87
+ out_wav = out_wav.data.numpy()
88
+ return out_wav
89
+
90
+ def file_separate(
91
+ self, filename: str, output_dir=None, force_overwrite=False, **kwargs
92
+ ) -> None:
93
+ """ Filename interface to `separate`."""
94
+ import soundfile as sf
95
+
96
+ wav, fs = sf.read(filename, dtype="float32", always_2d=True)
97
+ # FIXME: support only single-channel files for now.
98
+ to_save = self.numpy_separate(wav[:, 0], **kwargs)
99
+
100
+ # Save wav files to filename_est1.wav etc...
101
+ for src_idx, est_src in enumerate(to_save):
102
+ base = ".".join(filename.split(".")[:-1])
103
+ save_name = base + "_est{}.".format(src_idx + 1) + filename.split(".")[-1]
104
+ if os.path.isfile(save_name) and not force_overwrite:
105
+ warnings.warn(
106
+ f"File {save_name} already exists, pass `force_overwrite=True` to overwrite it",
107
+ UserWarning,
108
+ )
109
+ return
110
+ if output_dir is not None:
111
+ save_name = os.path.join(output_dir, save_name.split("/")[-1])
112
+ sf.write(save_name, est_src, fs)
113
+
114
+ def _separate(self, wav, *args, **kwargs):
115
+ """Hidden separation method
116
+
117
+ Args:
118
+ wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
119
+ Shape: 1D, 2D or 3D tensor, time last.
120
+
121
+ Returns:
122
+ The output of self(wav, *args, **kwargs).
123
+ """
124
+ return self(wav, *args, **kwargs)
125
+
126
+     @classmethod
+     def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs):
+         """Instantiate separation model from a model config (file or dict).
+
+         Args:
+             pretrained_model_conf_or_path (Union[dict, str]): model conf as
+                 returned by `serialize`, or path to it. Needs to contain
+                 `model_args` and `state_dict` keys.
+             *args: Positional arguments to be passed to the model.
+             **kwargs: Keyword arguments to be passed to the model.
+                 They overwrite the ones in the model package.
+
+         Returns:
+             nn.Module corresponding to the pretrained model conf/URL.
+
+         Raises:
+             ValueError if the input config file doesn't contain the keys
+             `model_name`, `model_args` or `state_dict`.
+         """
+         from . import get  # Avoid circular imports
+
+         if isinstance(pretrained_model_conf_or_path, str):
+             # cached_model = self.cached_download(pretrained_model_conf_or_path)
+             if os.path.isfile(pretrained_model_conf_or_path):
+                 cached_model = pretrained_model_conf_or_path
+             else:
+                 raise ValueError(
+                     "Model {} is not a file or doesn't exist.".format(pretrained_model_conf_or_path)
+                 )
+
+             conf = torch.load(cached_model, map_location="cpu")
+         else:
+             conf = pretrained_model_conf_or_path
+
+         if "model_name" not in conf.keys():
+             raise ValueError(
+                 "Expected config dictionary to have field "
+                 "`model_name`. Found only: {}".format(conf.keys())
+             )
+         if "state_dict" not in conf.keys():
+             raise ValueError(
+                 "Expected config dictionary to have field "
+                 "`state_dict`. Found only: {}".format(conf.keys())
+             )
+         if "model_args" not in conf.keys():
+             raise ValueError(
+                 "Expected config dictionary to have field "
+                 "`model_args`. Found only: {}".format(conf.keys())
+             )
+         conf["model_args"].update(kwargs)  # kwargs overwrite config.
+         # Attempt to find the model and instantiate it.
+         try:
+             model_class = get(conf["model_name"])
+         except ValueError:  # Couldn't get the model, maybe custom.
+             model = cls(*args, **conf["model_args"])  # Child class.
+         else:
+             model = model_class(*args, **conf["model_args"])
+         model.load_state_dict(conf["state_dict"])
+         return model
+
+     def serialize(self):
+         """Serialize model and output dictionary.
+
+         Returns:
+             dict, serialized model with keys `model_args` and `state_dict`.
+         """
+         import pytorch_lightning as pl  # Not used in torch.hub
+
+         from .. import __version__ as asteroid_version  # Avoid circular imports
+
+         model_conf = dict(
+             model_name=self.__class__.__name__,
+             state_dict=self.get_state_dict(),
+             model_args=self.get_model_args(),
+         )
+         # Additional infos
+         infos = dict()
+         infos["software_versions"] = dict(
+             torch_version=torch.__version__,
+             pytorch_lightning_version=pl.__version__,
+             asteroid_version=asteroid_version,
+         )
+         model_conf["infos"] = infos
+         return model_conf
+
+     def get_state_dict(self):
+         """ In case the state dict needs to be modified before sharing the model."""
+         return self.state_dict()
+
+     def get_model_args(self):
+         raise NotImplementedError
+
+     def cached_download(self, filename_or_url):
+         if os.path.isfile(filename_or_url):
+             print("is file")
+             return filename_or_url
+         # Raise instead of silently returning None when the path is invalid.
+         raise ValueError("Model {} is not a file or doesn't exist.".format(filename_or_url))
+
+
+ class BaseEncoderMaskerDecoder(BaseModel):
+     """Base class for encoder-masker-decoder separation models.
+
+     Args:
+         encoder (Encoder): Encoder instance.
+         masker (nn.Module): masker network.
+         decoder (Decoder): Decoder instance.
+         encoder_activation (Optional[str], optional): Activation to apply after encoder.
+             See ``asteroid.masknn.activations`` for valid values.
+     """
+
+     def __init__(self, encoder, masker, decoder, encoder_activation=None):
+         super().__init__()
+         self.encoder = encoder
+         self.masker = masker
+         self.decoder = decoder
+
+         self.encoder_activation = encoder_activation
+         self.enc_activation = activations.get(encoder_activation or "linear")()
+
+     def forward(self, wav):
+         """Enc/Mask/Dec model forward.
+
+         Args:
+             wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
+
+         Returns:
+             torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
+         """
+         # Handle 1D, 2D or n-D inputs
+         was_one_d = wav.ndim == 1
+         # Reshape to (batch, n_mix, time)
+         wav = _unsqueeze_to_3d(wav)
+
+         # Real forward
+         tf_rep = self.encoder(wav)
+         tf_rep = self.postprocess_encoded(tf_rep)
+         tf_rep = self.enc_activation(tf_rep)
+
+         est_masks = self.masker(tf_rep)
+         est_masks = self.postprocess_masks(est_masks)
+
+         masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
+         masked_tf_rep = self.postprocess_masked(masked_tf_rep)
+
+         decoded = self.decoder(masked_tf_rep)
+         decoded = self.postprocess_decoded(decoded)
+
+         reconstructed = pad_x_to_y(decoded, wav)
+         if was_one_d:
+             return reconstructed.squeeze(0)
+         else:
+             return reconstructed
+
+     def postprocess_encoded(self, tf_rep):
+         """Hook to perform transformations on the encoded, time-frequency domain
+         representation (output of the encoder) before encoder activation is applied.
+
+         Args:
+             tf_rep (Tensor of shape (batch, freq, time)):
+                 Output of the encoder, before encoder activation is applied.
+
+         Return:
+             Transformed `tf_rep`
+         """
+         return tf_rep
+
+     def postprocess_masks(self, masks):
+         """Hook to perform transformations on the masks (output of the masker) before
+         masks are applied.
+
+         Args:
+             masks (Tensor of shape (batch, n_src, freq, time)):
+                 Output of the masker
+
+         Return:
+             Transformed `masks`
+         """
+         return masks
+
+     def postprocess_masked(self, masked_tf_rep):
+         """Hook to perform transformations on the masked time-frequency domain
+         representation (result of masking in the time-frequency domain) before decoding.
+
+         Args:
+             masked_tf_rep (Tensor of shape (batch, n_src, freq, time)):
+                 Masked time-frequency representation, before decoding.
+
+         Return:
+             Transformed `masked_tf_rep`
+         """
+         return masked_tf_rep
+
+     def postprocess_decoded(self, decoded):
+         """Hook to perform transformations on the decoded, time domain representation
+         (output of the decoder) before original shape reconstruction.
+
+         Args:
+             decoded (Tensor of shape (batch, n_src, time)):
+                 Output of the decoder, before original shape reconstruction.
+
+         Return:
+             Transformed `decoded`
+         """
+         return decoded
+
+     def get_model_args(self):
+         """ Arguments needed to re-instantiate the model. """
+         fb_config = self.encoder.filterbank.get_config()
+         masknet_config = self.masker.get_config()
+         # Assert both dicts are disjoint
+         if not all(k not in fb_config for k in masknet_config):
+             raise AssertionError(
+                 "Filterbank and Mask network config share "
+                 "common keys. Merging them is not safe."
+             )
+         # Merge all args under model_args.
+         model_args = {
+             **fb_config,
+             **masknet_config,
+             "encoder_activation": self.encoder_activation,
+         }
+         return model_args
+
+
+ # Backwards compatibility
+ BaseTasNet = BaseEncoderMaskerDecoder
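
Note on usage: `separate` dispatches on the input type, so numpy arrays go through `numpy_separate`, torch tensors through `torch_separate`, and filenames through `file_separate` (via `soundfile`). A minimal sketch, assuming the `DPTNet` class uploaded below; the constructor values are illustrative, not prescribed by this diff:

import numpy as np
import torch
from DPTNet_eval import asteroid_test

# Illustrative instantiation (any valid DPTNet arguments work here).
model = asteroid_test.DPTNet(
    n_src=2, ff_hid=256, chunk_size=100, hop_size=50, n_repeats=2,
    mask_act="sigmoid", n_filters=64, kernel_size=16, stride=8,
)
model.eval()

with torch.no_grad():
    mix = torch.randn(1, 16000)                    # (batch, time)
    est = model.torch_separate(mix)                # (batch, n_src, time)
    est_np = model.numpy_separate(mix[0].numpy())  # numpy in, numpy out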
DPTNet_eval/asteroid_test/models/dptnet.py ADDED
@@ -0,0 +1,96 @@
+ from ..filterbanks import make_enc_dec
+ from ..masknn import DPTransformer
+ from .base_models import BaseEncoderMaskerDecoder
+
+
+ class DPTNet(BaseEncoderMaskerDecoder):
+     """DPTNet separation model, as described in [1].
+
+     Args:
+         n_src (int): Number of masks to estimate.
+         ff_hid (int): Number of neurons in the feed-forward layers of the
+             transformer blocks. Defaults to 256.
+         chunk_size (int): window size of overlap and add processing.
+             Defaults to 100.
+         hop_size (int or None): hop size (stride) of overlap and add processing.
+             Defaults to `chunk_size // 2` (50% overlap).
+         n_repeats (int): Number of repeats. Defaults to 6.
+         norm_type (str, optional): Type of normalization to use. To choose from
+
+             - ``'gLN'``: global Layernorm
+             - ``'cLN'``: channelwise Layernorm
+         ff_activation (str, optional): Activation of the feed-forward layers.
+         encoder_activation (str, optional): Activation to apply after the encoder.
+         mask_act (str, optional): Which non-linear function to generate mask.
+         bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
+             (Intra-Chunk is always bidirectional).
+         dropout (float, optional): Dropout ratio, must be in [0, 1].
+         in_chan (int, optional): Number of input channels, should be equal to
+             n_filters.
+         fb_name (str, className): Filterbank family from which to make encoder
+             and decoder. To choose among [``'free'``, ``'analytic_free'``,
+             ``'param_sinc'``, ``'stft'``].
+         n_filters (int): Number of filters / Input dimension of the masker net.
+         kernel_size (int): Length of the filters.
+         stride (int, optional): Stride of the convolution.
+             If None (default), set to ``kernel_size // 2``.
+         **fb_kwargs (dict): Additional kwargs to pass to the filterbank
+             creation.
+
+     References:
+         [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct
+             Context-Aware Modeling for End-to-End Monaural Speech Separation"
+             Interspeech 2020.
+     """
+
+     def __init__(
+         self,
+         n_src,
+         ff_hid=256,
+         chunk_size=100,
+         hop_size=None,
+         n_repeats=6,
+         norm_type="gLN",
+         ff_activation="relu",
+         encoder_activation="relu",
+         mask_act="relu",
+         bidirectional=True,
+         dropout=0,
+         in_chan=None,
+         fb_name="free",
+         kernel_size=16,
+         n_filters=64,
+         stride=8,
+         **fb_kwargs,
+     ):
+         encoder, decoder = make_enc_dec(
+             fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs
+         )
+         n_feats = encoder.n_feats_out
+         if in_chan is not None:
+             assert in_chan == n_feats, (
+                 "Number of filterbank output channels"
+                 " and number of input channels should "
+                 "be the same. Received "
+                 f"{n_feats} and {in_chan}"
+             )
+         masker = DPTransformer(
+             n_feats,
+             n_src,
+             ff_hid=ff_hid,
+             ff_activation=ff_activation,
+             chunk_size=chunk_size,
+             hop_size=hop_size,
+             n_repeats=n_repeats,
+             norm_type=norm_type,
+             mask_act=mask_act,
+             bidirectional=bidirectional,
+             dropout=dropout,
+         )
+         super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
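
A quick shape check for the class above, a minimal sketch assuming the default free filterbank and the package layout of this commit; the 8000-sample input is arbitrary:

import torch
from DPTNet_eval.asteroid_test.models.dptnet import DPTNet

model = DPTNet(n_src=2, n_repeats=2, chunk_size=100, hop_size=50, mask_act="sigmoid")
model.eval()

wav = torch.randn(8000)       # 1D input, time last
with torch.no_grad():
    est_sources = model(wav)  # (n_src, time), padded back to the input length
assert est_sources.shape == (2, 8000)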
DPTNet_eval/asteroid_test/utils/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .torch_utils import tensors_to_device, to_cuda
+
+ # The functions above were all in asteroid/utils.py before refactoring into
+ # asteroid/utils/*_utils.py files. They are imported for backward compatibility.
+
+ __all__ = [
+     "tensors_to_device",
+     "to_cuda",
+ ]
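
These re-exports keep the pre-refactor import path working; for example (path as laid out in this commit):

from DPTNet_eval.asteroid_test.utils import tensors_to_device, to_cuda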
DPTNet_eval/asteroid_test/utils/torch_utils.py ADDED
@@ -0,0 +1,126 @@
+ import torch
+ from torch import nn
+ from collections import OrderedDict
+
+
+ def to_cuda(tensors):  # pragma: no cover (No CUDA on travis)
+     """Transfer tensor, dict or list of tensors to GPU.
+
+     Args:
+         tensors (:class:`torch.Tensor`, list or dict): May be a single, a
+             list or a dictionary of tensors.
+
+     Returns:
+         :class:`torch.Tensor`:
+             Same as input but transferred to cuda. Goes through lists and dicts
+             and transfers the torch.Tensor to cuda. Leaves the rest untouched.
+     """
+     if isinstance(tensors, torch.Tensor):
+         return tensors.cuda()
+     if isinstance(tensors, list):
+         return [to_cuda(tens) for tens in tensors]
+     if isinstance(tensors, dict):
+         for key in tensors.keys():
+             tensors[key] = to_cuda(tensors[key])
+         return tensors
+     raise TypeError(
+         "tensors must be a tensor or a list or dict of tensors. "
+         " Got tensors of type {}".format(type(tensors))
+     )
+
+
+ def tensors_to_device(tensors, device):
+     """Transfer tensor, dict or list of tensors to device.
+
+     Args:
+         tensors (:class:`torch.Tensor`): May be a single, a list or a
+             dictionary of tensors.
+         device (:class:`torch.device`): the device where to place the tensors.
+
+     Returns:
+         Union[:class:`torch.Tensor`, list, tuple, dict]:
+             Same as input but transferred to device.
+             Goes through lists and dicts and transfers the torch.Tensor to
+             device. Leaves the rest untouched.
+     """
+     if isinstance(tensors, torch.Tensor):
+         return tensors.to(device)
+     elif isinstance(tensors, (list, tuple)):
+         return [tensors_to_device(tens, device) for tens in tensors]
+     elif isinstance(tensors, dict):
+         for key in tensors.keys():
+             tensors[key] = tensors_to_device(tensors[key], device)
+         return tensors
+     else:
+         return tensors
+
+
+ def pad_x_to_y(x, y, axis=-1):
+     """Pad first argument to have same size as second argument.
+
+     Args:
+         x (torch.Tensor): Tensor to be padded.
+         y (torch.Tensor): Tensor to pad x to.
+         axis (int): Axis to pad on.
+
+     Returns:
+         torch.Tensor, x padded to match y's shape.
+     """
+     if axis != -1:
+         raise NotImplementedError
+     inp_len = y.size(axis)
+     output_len = x.size(axis)
+     return nn.functional.pad(x, [0, inp_len - output_len])
+
+
+ def load_state_dict_in(state_dict, model):
+     """Strictly loads state_dict in model, or the next submodel.
+     Useful to load a standalone model after training it with System.
+
+     Args:
+         state_dict (OrderedDict): the state_dict to load.
+         model (torch.nn.Module): the model to load it into.
+
+     Returns:
+         torch.nn.Module: model with loaded weights.
+
+     .. note:: Keys in a state_dict look like ``object1.object2.layer_name.weight.etc``.
+         We first try to load the model in the classic way. If this fails, we
+         remove the first left part of each key to obtain
+         ``object2.layer_name.weight.etc``. Blindly loading with ``strict=False``
+         should instead be done with some logging of the keys missing from the
+         state_dict and the model.
+     """
+     try:
+         # This can fail if the model was included into a bigger nn.Module
+         # object. For example, into System.
+         model.load_state_dict(state_dict, strict=True)
+     except RuntimeError:
+         # Keys look like object1.object2.layer_name.weight.etc
+         # The following removes the first left part of each key to obtain
+         # object2.layer_name.weight.etc, then loads strictly again.
+         new_state_dict = OrderedDict()
+         for k, v in state_dict.items():
+             new_k = k[k.find(".") + 1 :]
+             new_state_dict[new_k] = v
+         model.load_state_dict(new_state_dict, strict=True)
+     return model
+
+
+ def are_models_equal(model1, model2):
+     """Check for weights equality between models.
+
+     Args:
+         model1 (nn.Module): model instance to be compared.
+         model2 (nn.Module): second model instance to be compared.
+
+     Returns:
+         bool: Whether all model weights are equal.
+     """
+     for p1, p2 in zip(model1.parameters(), model2.parameters()):
+         if p1.data.ne(p2.data).sum() > 0:
+             return False
+     return True
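
A minimal sketch exercising the helpers above; the import path assumes the layout of this commit (`pad_x_to_y` and `load_state_dict_in` are defined here but not re-exported by `utils/__init__.py`):

import torch
from DPTNet_eval.asteroid_test.utils.torch_utils import pad_x_to_y, tensors_to_device

# Zero-pad a decoded batch back to the mixture length on the last axis.
decoded = torch.randn(1, 2, 7992)
mixture = torch.randn(1, 1, 8000)
assert pad_x_to_y(decoded, mixture).shape == (1, 2, 8000)

# Recurses through dicts/lists, moving only tensors; non-tensors pass through.
batch = {"mix": torch.randn(4, 8000), "ids": ["a", "b", "c", "d"]}
batch = tensors_to_device(batch, torch.device("cpu"))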