Upload 16 files
- DPTNet_eval/DPTNet_quant_sep.py +108 -0
- DPTNet_eval/asteroid_test/__init__.py +19 -0
- DPTNet_eval/asteroid_test/dsp/__init__.py +5 -0
- DPTNet_eval/asteroid_test/dsp/overlap_add.py +317 -0
- DPTNet_eval/asteroid_test/filterbanks/__init__.py +107 -0
- DPTNet_eval/asteroid_test/filterbanks/enc_dec.py +267 -0
- DPTNet_eval/asteroid_test/filterbanks/free_fb.py +33 -0
- DPTNet_eval/asteroid_test/masknn/__init__.py +12 -0
- DPTNet_eval/asteroid_test/masknn/activations.py +82 -0
- DPTNet_eval/asteroid_test/masknn/attention.py +271 -0
- DPTNet_eval/asteroid_test/masknn/norms.py +156 -0
- DPTNet_eval/asteroid_test/models/__init__.py +59 -0
- DPTNet_eval/asteroid_test/models/base_models.py +351 -0
- DPTNet_eval/asteroid_test/models/dptnet.py +96 -0
- DPTNet_eval/asteroid_test/utils/__init__.py +9 -0
- DPTNet_eval/asteroid_test/utils/torch_utils.py +126 -0
DPTNet_eval/DPTNet_quant_sep.py
ADDED
@@ -0,0 +1,108 @@
# DPTNet_quant_sep.py

import os
import torch
import numpy as np
import torchaudio
from huggingface_hub import hf_hub_download
from . import asteroid_test

torchaudio.set_audio_backend("sox_io")


def get_conf():
    conf_filterbank = {
        'n_filters': 64,
        'kernel_size': 16,
        'stride': 8
    }

    conf_masknet = {
        'in_chan': 64,
        'n_src': 2,
        'out_chan': 64,
        'ff_hid': 256,
        'ff_activation': "relu",
        'norm_type': "gLN",
        'chunk_size': 100,
        'hop_size': 50,
        'n_repeats': 2,
        'mask_act': 'sigmoid',
        'bidirectional': True,
        'dropout': 0
    }
    return conf_filterbank, conf_masknet


def load_dpt_model():
    print('Load Separation Model...')

    # Read the Hugging Face token from the environment
    HF_TOKEN = os.getenv("HF_TOKEN")
    if not HF_TOKEN:
        raise EnvironmentError("The HF_TOKEN environment variable is not set! Run `export HF_TOKEN=xxx` first.")

    # Download the model weights from the Hugging Face Hub
    model_path = hf_hub_download(
        repo_id="DeepLearning101/speech-separation",  # ← replace with your repo name
        filename="train_dptnet_aishell_partOverlap_B2_300epoch_quan-int8.p",
        token=HF_TOKEN
    )

    # Get the model configuration
    conf_filterbank, conf_masknet = get_conf()

    # Build the model architecture
    model_class = getattr(asteroid_test, "DPTNet")
    model = model_class(**conf_filterbank, **conf_masknet)

    # Apply the dynamic quantization setup
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.LSTM, torch.nn.Linear},
        dtype=torch.qint8
    )

    # Load the weights (ignore mismatched keys)
    state_dict = torch.load(model_path, map_location="cpu")
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
    model.load_state_dict(filtered_state_dict, strict=False)
    model.eval()

    return model


def dpt_sep_process(wav_path, model=None, outfilename=None):
    if model is None:
        model = load_dpt_model()

    x, sr = torchaudio.load(wav_path)
    x = x.cpu()

    with torch.no_grad():
        est_sources = model(x)  # shape: (1, 2, T)

    est_sources = est_sources.squeeze(0)  # shape: (2, T)
    sep_1, sep_2 = est_sources  # split into two (T,) tensors

    # Normalize each source to the mixture's peak level
    max_abs = x[0].abs().max().item()
    sep_1 = sep_1 * max_abs / sep_1.abs().max().item()
    sep_2 = sep_2 * max_abs / sep_2.abs().max().item()

    # Add a channel dimension -> (1, T)
    sep_1 = sep_1.unsqueeze(0)
    sep_2 = sep_2.unsqueeze(0)

    # Save the results
    if outfilename is not None:
        torchaudio.save(outfilename.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(outfilename.replace('.wav', '_sep2.wav'), sep_2, sr)
        torchaudio.save(outfilename.replace('.wav', '_mix.wav'), x, sr)
    else:
        torchaudio.save(wav_path.replace('.wav', '_sep1.wav'), sep_1, sr)
        torchaudio.save(wav_path.replace('.wav', '_sep2.wav'), sep_2, sr)


if __name__ == '__main__':
    print("This module should be used via Flask or Gradio.")
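A minimal usage sketch of this entry point, not part of the commit. The mixture path mix.wav and the output prefix out.wav are assumptions, the repo root must be on sys.path, and HF_TOKEN must be exported first:

from DPTNet_eval.DPTNet_quant_sep import load_dpt_model, dpt_sep_process

model = load_dpt_model()  # downloads and quantizes the checkpoint once
dpt_sep_process("mix.wav", model=model, outfilename="out.wav")
# writes out_sep1.wav, out_sep2.wav and out_mix.wav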
DPTNet_eval/asteroid_test/__init__.py
ADDED
@@ -0,0 +1,19 @@
import pathlib

from .models import DPTNet
from .utils import torch_utils  # noqa

project_root = str(pathlib.Path(__file__).expanduser().absolute().parent.parent)
__version__ = "0.3.4"


def show_available_models():
    from .utils.hub_utils import MODELS_URLS_HASHTABLE

    print(" \n".join(list(MODELS_URLS_HASHTABLE.keys())))


__all__ = [
    "DPTNet",
    "show_available_models",
]
DPTNet_eval/asteroid_test/dsp/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .overlap_add import DualPathProcessing

__all__ = [
    "DualPathProcessing",
]
DPTNet_eval/asteroid_test/dsp/overlap_add.py
ADDED
@@ -0,0 +1,317 @@
import torch
from scipy.signal import get_window
# from asteroid_test.losses import PITLossWrapper
from torch import nn

'''
class LambdaOverlapAdd(torch.nn.Module):
    """Overlap-add with lambda transform on segments.

    Segment input signal, apply lambda function (a neural network for example)
    and combine with OLA.

    Args:
        nnet (callable): Function to apply to each segment.
        n_src (int): Number of sources in the output of nnet.
        window_size (int): Size of segmenting window.
        hop_size (int): Segmentation hop size.
        window (str): Name of the window (see scipy.signal.get_window) used
            for the synthesis.
        reorder_chunks (bool): Whether to reorder each consecutive segment.
            This might be useful when `nnet` is permutation invariant, as
            source assignments might change output channel from one segment
            to the next (in classic speech separation for example).
            Reordering is performed based on the correlation between
            the overlapped part of consecutive segments.

    Examples:
        >>> from asteroid_test import ConvTasNet
        >>> nnet = ConvTasNet(n_src=2)
        >>> continuous_nnet = LambdaOverlapAdd(
        >>>     nnet=nnet,
        >>>     n_src=2,
        >>>     window_size=64000,
        >>>     hop_size=None,
        >>>     window="hanning",
        >>>     reorder_chunks=True,
        >>>     enable_grad=False,
        >>> )
        >>> wav = torch.randn(1, 1, 500000)
        >>> out_wavs = continuous_nnet.forward(wav)
    """

    def __init__(
        self,
        nnet,
        n_src,
        window_size,
        hop_size=None,
        window="hanning",
        reorder_chunks=True,
        enable_grad=False,
    ):
        super().__init__()
        assert window_size % 2 == 0, "Window size must be even"

        self.nnet = nnet
        self.window_size = window_size
        self.hop_size = hop_size if hop_size is not None else window_size // 2
        self.n_src = n_src

        if window:
            window = get_window(window, self.window_size).astype("float32")
            window = torch.from_numpy(window)
            self.use_window = True
        else:
            self.use_window = False

        self.register_buffer("window", window)
        self.reorder_chunks = reorder_chunks
        self.enable_grad = enable_grad

    def ola_forward(self, x):
        """Heart of the class: segment signal, apply func, combine with OLA."""

        assert x.ndim == 3

        batch, channels, n_frames = x.size()
        # Overlap and add:
        # [batch, chans, n_frames] -> [batch, chans, win_size, n_chunks]
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )

        out = []
        n_chunks = unfolded.shape[-1]
        for frame_idx in range(n_chunks):  # for loop to spare memory
            frame = self.nnet(unfolded[..., frame_idx])
            # user must handle multichannel by reshaping to batch
            if frame_idx == 0:
                assert frame.ndim == 3, "nnet should return (batch, n_src, time)"
                assert frame.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
            frame = frame.reshape(batch * self.n_src, -1)

            if frame_idx != 0 and self.reorder_chunks:
                # we determine best perm based on xcorr with previous sources
                frame = _reorder_sources(
                    frame, out[-1], self.n_src, self.window_size, self.hop_size
                )

            if self.use_window:
                frame = frame * self.window
            else:
                frame = frame / (self.window_size / self.hop_size)
            out.append(frame)

        out = torch.stack(out).reshape(n_chunks, batch * self.n_src, self.window_size)
        out = out.permute(1, 2, 0)

        out = torch.nn.functional.fold(
            out,
            (n_frames, 1),
            kernel_size=(self.window_size, 1),
            padding=(self.window_size, 0),
            stride=(self.hop_size, 1),
        )
        return out.squeeze(-1).reshape(batch, self.n_src, -1)

    def forward(self, x):
        """Forward module: segment signal, apply func, combine with OLA.

        Args:
            x (:class:`torch.Tensor`): waveform signal of shape (batch, 1, time).

        Returns:
            :class:`torch.Tensor`: The output of the lambda OLA.
        """
        # Here we can do the reshaping
        with torch.autograd.set_grad_enabled(self.enable_grad):
            olad = self.ola_forward(x)
            return olad


def _reorder_sources(
    current: torch.FloatTensor,
    previous: torch.FloatTensor,
    n_src: int,
    window_size: int,
    hop_size: int,
):
    """
    Reorder sources in current chunk to maximize correlation with previous chunk.
    Used for Continuous Source Separation. Standard dsp correlation is used
    for reordering.


    Args:
        current (:class:`torch.Tensor`): current chunk, tensor
            of shape (batch, n_src, window_size)
        previous (:class:`torch.Tensor`): previous chunk, tensor
            of shape (batch, n_src, window_size)
        n_src (:class:`int`): number of sources.
        window_size (:class:`int`): window_size, equal to last dimension of
            both current and previous.
        hop_size (:class:`int`): hop_size between current and previous tensors.

    Returns:
        current:

    """
    batch, frames = current.size()
    current = current.reshape(-1, n_src, frames)
    previous = previous.reshape(-1, n_src, frames)

    overlap_f = window_size - hop_size

    def reorder_func(x, y):
        x = x[..., :overlap_f]
        y = y[..., -overlap_f:]
        # Mean normalization
        x = x - x.mean(-1, keepdim=True)
        y = y - y.mean(-1, keepdim=True)
        # Negative mean Correlation
        return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)

    # We maximize correlation-like between previous and current.
    pit = PITLossWrapper(reorder_func)
    current = pit(current, previous, return_est=True)[1]
    return current.reshape(batch, frames)
'''


class DualPathProcessing(nn.Module):
    """Perform Dual-Path processing via overlap-add as in DPRNN [1].

    Args:
        chunk_size (int): Size of segmenting window.
        hop_size (int): segmentation hop size.

    References:
        [1] "Dual-path RNN: efficient long sequence modeling for
        time-domain single-channel speech separation", Yi Luo, Zhuo Chen
        and Takuya Yoshioka. https://arxiv.org/abs/1910.06379
    """

    def __init__(self, chunk_size, hop_size):
        super(DualPathProcessing, self).__init__()
        self.chunk_size = chunk_size
        self.hop_size = hop_size
        self.n_orig_frames = None

    def unfold(self, x):
        """Unfold the feature tensor from

        (batch, channels, time) to (batch, channels, chunk_size, n_chunks).

        Args:
            x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).

        Returns:
            x: (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        """
        # x is (batch, chan, frames)
        batch, chan, frames = x.size()
        assert x.ndim == 3
        self.n_orig_frames = x.shape[-1]
        unfolded = torch.nn.functional.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )

        return unfolded.reshape(
            batch, chan, self.chunk_size, -1
        )  # (batch, chan, chunk_size, n_chunks)

    def fold(self, x, output_size=None):
        """Folds back the spliced feature tensor.

        Input shape (batch, channels, chunk_size, n_chunks) to original shape
        (batch, channels, time) using overlap-add.

        Args:
            x: (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            output_size: (int, optional): sequence length of original feature tensor.
                If None, the original length cached by the previous call of `unfold`
                will be used.

        Returns:
            x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).

        .. note:: `fold` caches the original length of the previously unfolded
            tensor and restores it on output.

        """
        output_size = output_size if output_size is not None else self.n_orig_frames
        # x is (batch, chan, chunk_size, n_chunks)
        batch, chan, chunk_size, n_chunks = x.size()
        to_unfold = x.reshape(batch, chan * self.chunk_size, n_chunks)
        x = torch.nn.functional.fold(
            to_unfold,
            (output_size, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )

        x /= self.chunk_size / self.hop_size

        return x.reshape(batch, chan, self.n_orig_frames)

    @staticmethod
    def intra_process(x, module):
        """Performs intra-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module one wishes to apply to each chunk
                of the spliced feature tensor.


        Returns:
            x (:class:`torch.Tensor`): processed spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape (batch, channels, time).
        """

        # x is (batch, channels, chunk_size, n_chunks)
        batch, channels, chunk_size, n_chunks = x.size()
        # we reshape to batch*chunk_size, channels, n_chunks
        x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
        x = module(x)
        x = x.reshape(batch, n_chunks, channels, chunk_size).transpose(1, -1).transpose(1, 2)
        return x

    @staticmethod
    def inter_process(x, module):
        """Performs inter-chunk processing.

        Args:
            x (:class:`torch.Tensor`): spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).
            module (:class:`torch.nn.Module`): module one wishes to apply between
                each chunk of the spliced feature tensor.


        Returns:
            x (:class:`torch.Tensor`): processed spliced feature tensor of shape
                (batch, channels, chunk_size, n_chunks).

        .. note:: the module should have the channel first convention and accept
            a 3D tensor of shape (batch, channels, time).
        """

        batch, channels, chunk_size, n_chunks = x.size()
        x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
        x = module(x)
        x = x.reshape(batch, chunk_size, channels, n_chunks).transpose(1, 2)
        return x
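As a sanity check on the overlap-add logic, a small round-trip sketch (shapes chosen to match the DPTNet config above; random values, not part of the commit):

import torch
from DPTNet_eval.asteroid_test.dsp import DualPathProcessing

ola = DualPathProcessing(chunk_size=100, hop_size=50)
feats = torch.randn(2, 64, 1000)   # (batch, channels, time)
chunks = ola.unfold(feats)         # (2, 64, 100, n_chunks)
rec = ola.fold(chunks)             # (2, 64, 1000): overlap-add, rescaled by chunk_size / hop_size
print(chunks.shape, rec.shape)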
DPTNet_eval/asteroid_test/filterbanks/__init__.py
ADDED
@@ -0,0 +1,107 @@
# from .analytic_free_fb import AnalyticFreeFB
from .free_fb import FreeFB
from .enc_dec import Filterbank, Encoder, Decoder


def make_enc_dec(
    fb_name,
    n_filters,
    kernel_size,
    stride=None,
    who_is_pinv=None,
    padding=0,
    output_padding=0,
    **kwargs,
):
    """Creates congruent encoder and decoder from the same filterbank family.

    Args:
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
            submodule in this subpackage (e.g. :class:`~.FreeFB`).
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
            be used. If string (among [``'encoder'``, ``'decoder'``]), decides
            which of ``Encoder`` or ``Decoder`` will be the pseudo inverse of
            the other one.
        padding (int): Zero-padding added to both sides of the input.
            Passed to Encoder and Decoder.
        output_padding (int): Additional size added to one side of the output shape.
            Passed to Decoder.
        **kwargs: Arguments which will be passed to the filterbank class
            additionally to the usual `n_filters`, `kernel_size` and `stride`.
            Depends on the filterbank family.
    Returns:
        :class:`.Encoder`, :class:`.Decoder`
    """
    fb_class = get(fb_name)

    if who_is_pinv in ["dec", "decoder"]:
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        enc = Encoder(fb, padding=padding)
        # Decoder filterbank is pseudo inverse of encoder filterbank.
        dec = Decoder.pinv_of(fb)
    elif who_is_pinv in ["enc", "encoder"]:
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
        # Encoder filterbank is pseudo inverse of decoder filterbank.
        enc = Encoder.pinv_of(fb)
    else:
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        enc = Encoder(fb, padding=padding)
        # Filters between encoder and decoder should not be shared.
        fb = fb_class(n_filters, kernel_size, stride=stride, **kwargs)
        dec = Decoder(fb, padding=padding, output_padding=output_padding)
    return enc, dec


def register_filterbank(custom_fb):
    """Register a custom filterbank, gettable with `filterbanks.get`.

    Args:
        custom_fb: Custom filterbank to register.

    """
    if custom_fb.__name__ in globals().keys() or custom_fb.__name__.lower() in globals().keys():
        raise ValueError(f"Filterbank {custom_fb.__name__} already exists. Choose another name.")
    globals().update({custom_fb.__name__: custom_fb})


def get(identifier):
    """Returns a filterbank class from a string. Returns its input if it
    is callable (already a :class:`.Filterbank` for example).

    Args:
        identifier (str or Callable or None): the filterbank identifier.

    Returns:
        :class:`.Filterbank` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret filterbank identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret filterbank identifier: " + str(identifier))


# Aliases.
free = FreeFB

# For the docs
__all__ = [
    "Filterbank",
    "Encoder",
    "Decoder",
    "FreeFB",
    "make_enc_dec",
]
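A sketch of how make_enc_dec is typically used with the "free" family (the sizes mirror get_conf() in DPTNet_quant_sep.py; random input, not part of the commit):

import torch
from DPTNet_eval.asteroid_test.filterbanks import make_enc_dec

enc, dec = make_enc_dec("free", n_filters=64, kernel_size=16, stride=8)
wav = torch.randn(1, 1, 8000)   # (batch, 1, time)
tf_rep = enc(wav)               # (1, 64, 999): learned analysis transform
back = dec(tf_rep)              # (1, 1, 8000): transposed conv, i.e. overlap-add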
DPTNet_eval/asteroid_test/filterbanks/enc_dec.py
ADDED
@@ -0,0 +1,267 @@
import warnings
import torch
from torch import nn
from torch.nn import functional as F


class Filterbank(nn.Module):
    """Base Filterbank class.
    Each subclass has to implement a `filters` property.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the conv or transposed conv. (Hop size).
            If None (default), set to ``kernel_size // 2``.

    Attributes:
        n_feats_out (int): Number of output filters.
    """

    def __init__(self, n_filters, kernel_size, stride=None):
        super(Filterbank, self).__init__()
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.stride = stride if stride else self.kernel_size // 2
        # If not specified otherwise in the filterbank's init, output
        # number of features is equal to number of required filters.
        self.n_feats_out = n_filters

    @property
    def filters(self):
        """ Abstract method for filters. """
        raise NotImplementedError

    def get_config(self):
        """ Returns dictionary of arguments to re-instantiate the class. """
        config = {
            "fb_name": self.__class__.__name__,
            "n_filters": self.n_filters,
            "kernel_size": self.kernel_size,
            "stride": self.stride,
        }
        return config


class _EncDec(nn.Module):
    """Base private class for Encoder and Decoder.

    Common parameters and methods.

    Args:
        filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
            to use as an encoder or a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.

    Attributes:
        filterbank (:class:`Filterbank`)
        stride (int)
        is_pinv (bool)
    """

    def __init__(self, filterbank, is_pinv=False):
        super(_EncDec, self).__init__()
        self.filterbank = filterbank
        self.stride = self.filterbank.stride
        self.is_pinv = is_pinv

    @property
    def filters(self):
        return self.filterbank.filters

    def compute_filter_pinv(self, filters):
        """ Computes pseudo inverse filterbank of given filters."""
        scale = self.filterbank.stride / self.filterbank.kernel_size
        shape = filters.shape
        ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
        # Compensate for the overlap-add.
        return ifilt * scale

    def get_filters(self):
        """ Returns filters or pinv filters depending on `is_pinv` attribute """
        if self.is_pinv:
            return self.compute_filter_pinv(self.filters)
        else:
            return self.filters

    def get_config(self):
        """ Returns dictionary of arguments to re-instantiate the class."""
        config = {"is_pinv": self.is_pinv}
        base_config = self.filterbank.get_config()
        return dict(list(base_config.items()) + list(config.items()))


class Encoder(_EncDec):
    """Encoder class.

    Add encoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use
            as an encoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        as_conv1d (bool): Whether to behave like nn.Conv1d.
            If True (default), forwarding input with shape (batch, 1, time)
            will output a tensor of shape (batch, freq, conv_time).
            If False, will output a tensor of shape (batch, 1, freq, conv_time).
        padding (int): Zero-padding added to both sides of the input.

    """

    def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
        super(Encoder, self).__init__(filterbank, is_pinv=is_pinv)
        self.as_conv1d = as_conv1d
        self.n_feats_out = self.filterbank.n_feats_out
        self.padding = padding

    @classmethod
    def pinv_of(cls, filterbank, **kwargs):
        """Returns an :class:`~.Encoder`, pseudo inverse of a
        :class:`~.Filterbank` or :class:`~.Decoder`."""
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True, **kwargs)
        elif isinstance(filterbank, Decoder):
            return cls(filterbank.filterbank, is_pinv=True, **kwargs)

    def forward(self, waveform):
        """Convolve input waveform with the filters from a filterbank.
        Args:
            waveform (:class:`torch.Tensor`): any tensor with samples along the
                last dimension. The waveform representation, with any
                batch/channel etc. dimensions.
        Returns:
            :class:`torch.Tensor`: The corresponding TF domain signal.

        Shapes:
            >>> (time, ) --> (freq, conv_time)
            >>> (batch, time) --> (batch, freq, conv_time)  # Avoid
            >>> if as_conv1d:
            >>>     (batch, 1, time) --> (batch, freq, conv_time)
            >>>     (batch, chan, time) --> (batch, chan, freq, conv_time)
            >>> else:
            >>>     (batch, chan, time) --> (batch, chan, freq, conv_time)
            >>> (batch, any, dim, time) --> (batch, any, dim, freq, conv_time)
        """
        filters = self.get_filters()
        if waveform.ndim == 1:
            # Assumes 1D input with shape (time,)
            # Output will be (freq, conv_time)
            return F.conv1d(
                waveform[None, None], filters, stride=self.stride, padding=self.padding
            ).squeeze()
        elif waveform.ndim == 2:
            # Assume 2D input with shape (batch or channels, time)
            # Output will be (batch or channels, freq, conv_time)
            warnings.warn(
                "Input tensor was 2D. Applying the corresponding "
                "Decoder to the current output will result in a 3D "
                "tensor. This behaviour was introduced to match "
                "Conv1D and ConvTranspose1D, please use 3D inputs "
                "to avoid it. For example, this can be done with "
                "input_tensor.unsqueeze(1)."
            )
            return F.conv1d(
                waveform.unsqueeze(1), filters, stride=self.stride, padding=self.padding
            )
        elif waveform.ndim == 3:
            batch, channels, time_len = waveform.shape
            if channels == 1 and self.as_conv1d:
                # That's the common single channel case (batch, 1, time)
                # Output will be (batch, freq, stft_time), behaves as Conv1D
                return F.conv1d(waveform, filters, stride=self.stride, padding=self.padding)
            else:
                # Return batched convolution, input is (batch, 3, time),
                # output will be (batch, 3, freq, conv_time).
                # Useful for multichannel transforms
                # If as_conv1d is false, (batch, 1, time) will output
                # (batch, 1, freq, conv_time), useful for consistency.
                return self.batch_1d_conv(waveform, filters)
        else:  # waveform.ndim > 3
            # This is to compute "multi"multichannel convolution.
            # Input can be (*, time), output will be (*, freq, conv_time)
            return self.batch_1d_conv(waveform, filters)

    def batch_1d_conv(self, inp, filters):
        # Here we perform multichannel / multi-source convolution.
        # Output should be (batch, channels, freq, conv_time)
        batched_conv = F.conv1d(
            inp.view(-1, 1, inp.shape[-1]), filters, stride=self.stride, padding=self.padding
        )
        output_shape = inp.shape[:-1] + batched_conv.shape[-2:]
        return batched_conv.view(output_shape)


class Decoder(_EncDec):
    """Decoder class.

    Add decoding methods to Filterbank classes.
    Not intended to be subclassed.

    Args:
        filterbank (:class:`Filterbank`): The filterbank to use as a decoder.
        is_pinv (bool): Whether to be the pseudo inverse of filterbank.
        padding (int): Zero-padding added to both sides of the input.
        output_padding (int): Additional size added to one side of the
            output shape.

    Notes:
        `padding` and `output_padding` arguments are directly passed to
        F.conv_transpose1d.
    """

    def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
        super().__init__(filterbank, is_pinv=is_pinv)
        self.padding = padding
        self.output_padding = output_padding

    @classmethod
    def pinv_of(cls, filterbank):
        """ Returns a Decoder, pseudo inverse of a filterbank or Encoder."""
        if isinstance(filterbank, Filterbank):
            return cls(filterbank, is_pinv=True)
        elif isinstance(filterbank, Encoder):
            return cls(filterbank.filterbank, is_pinv=True)

    def forward(self, spec):
        """Applies transposed convolution to a TF representation.

        This is equivalent to overlap-add.

        Args:
            spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
                representation. (Output of :func:`Encoder.forward`).
        Returns:
            :class:`torch.Tensor`: The corresponding time domain signal.
        """
        filters = self.get_filters()
        if spec.ndim == 2:
            # Input is (freq, conv_time), output is (time)
            return F.conv_transpose1d(
                spec.unsqueeze(0),
                filters,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
            ).squeeze()
        if spec.ndim == 3:
            # Input is (batch, freq, conv_time), output is (batch, 1, time)
            return F.conv_transpose1d(
                spec,
                filters,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
            )
        elif spec.ndim > 3:
            # Multiply all the left dimensions together and group them in the
            # batch. Make the convolution and restore.
            view_as = (-1,) + spec.shape[-2:]
            out = F.conv_transpose1d(
                spec.view(view_as),
                filters,
                stride=self.stride,
                padding=self.padding,
                output_padding=self.output_padding,
            )
            return out.view(spec.shape[:-2] + (-1,))
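The rank-dependent behavior documented in Encoder.forward can be seen in this sketch (random tensors; not part of the commit):

import torch
from DPTNet_eval.asteroid_test.filterbanks import FreeFB
from DPTNet_eval.asteroid_test.filterbanks.enc_dec import Encoder, Decoder

fb = FreeFB(n_filters=64, kernel_size=16, stride=8)
enc = Encoder(fb)
print(enc(torch.randn(16000)).shape)         # (64, 1999): 1D input
print(enc(torch.randn(1, 1, 16000)).shape)   # (1, 64, 1999): Conv1d-like
print(enc(torch.randn(1, 2, 16000)).shape)   # (1, 2, 64, 1999): multichannel

dec = Decoder.pinv_of(enc)                   # decoder uses the pinverse of enc's filters
print(dec(enc(torch.randn(1, 1, 16000))).shape)  # (1, 1, 16000) round trip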
DPTNet_eval/asteroid_test/filterbanks/free_fb.py
ADDED
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from .enc_dec import Filterbank


class FreeFB(Filterbank):
    """Free filterbank without any constraints. Equivalent to
    :class:`nn.Conv1d`.

    Args:
        n_filters (int): Number of filters.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.

    Attributes:
        n_feats_out (int): Number of output filters.

    References:
        [1] : "Filterbank design for end-to-end speech separation".
        Submitted to ICASSP 2020. Manuel Pariente, Samuele Cornell,
        Antoine Deleforge, Emmanuel Vincent.
    """

    def __init__(self, n_filters, kernel_size, stride=None, **kwargs):
        super(FreeFB, self).__init__(n_filters, kernel_size, stride=stride)
        self._filters = nn.Parameter(torch.ones(n_filters, 1, kernel_size))
        for p in self.parameters():
            nn.init.xavier_normal_(p)

    @property
    def filters(self):
        return self._filters
DPTNet_eval/asteroid_test/masknn/__init__.py
ADDED
@@ -0,0 +1,12 @@
# from .convolutional import TDConvNet, TDConvNetpp, SuDORMRF, SuDORMRFImproved
# from .recurrent import DPRNN, LSTMMasker
from .attention import DPTransformer

__all__ = [
    # "TDConvNet",
    # "DPRNN",
    "DPTransformer",
    # "LSTMMasker",
    # "SuDORMRF",
    # "SuDORMRFImproved",
]
DPTNet_eval/asteroid_test/masknn/activations.py
ADDED
@@ -0,0 +1,82 @@
from functools import partial
import torch
from torch import nn


class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)


def linear():
    return nn.Identity()


def relu():
    return nn.ReLU()


def prelu():
    return nn.PReLU()


def leaky_relu():
    return nn.LeakyReLU()


def sigmoid():
    return nn.Sigmoid()


def softmax(dim=None):
    return nn.Softmax(dim=dim)


def tanh():
    return nn.Tanh()


def gelu():
    return nn.GELU()


def swish():
    return Swish()


def register_activation(custom_act):
    """Register a custom activation, gettable with `activations.get`.

    Args:
        custom_act: Custom activation function to register.

    """
    if custom_act.__name__ in globals().keys() or custom_act.__name__.lower() in globals().keys():
        raise ValueError(f"Activation {custom_act.__name__} already exists. Choose another name.")
    globals().update({custom_act.__name__: custom_act})


def get(identifier):
    """Returns an activation function from a string. Returns its input if it
    is callable (already an activation for example).

    Args:
        identifier (str or Callable or None): the activation identifier.

    Returns:
        :class:`nn.Module` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret activation identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret activation identifier: " + str(identifier))
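activations.get, norms.get and filterbanks.get all follow the same registry pattern: strings resolve through the module's globals(), callables pass through unchanged. A short sketch (the hard_tanh factory is hypothetical, not part of the commit):

import torch.nn as nn
from DPTNet_eval.asteroid_test.masknn import activations

act = activations.get("relu")()   # factory lookup by name, then instantiation

def hard_tanh():                  # hypothetical custom activation factory
    return nn.Hardtanh()

activations.register_activation(hard_tanh)
assert activations.get("hard_tanh") is hard_tanh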
DPTNet_eval/asteroid_test/masknn/attention.py
ADDED
@@ -0,0 +1,271 @@
from math import ceil
import warnings

import torch.nn as nn
from torch.nn.modules.activation import MultiheadAttention
from ..masknn import activations, norms
import torch
from ..dsp.overlap_add import DualPathProcessing

import inspect


class ImprovedTransformedLayer(nn.Module):
    """
    Improved Transformer module as used in [1].
    It is Multi-Head self-attention followed by LSTM, activation and linear projection layer.

    Args:
        embed_dim (int): Number of input channels.
        n_heads (int): Number of attention heads.
        dim_ff (int): Number of neurons in the RNNs cell state.
            Defaults to 256. RNN here replaces standard FF linear layer in plain Transformer.
        dropout (float, optional): Dropout ratio, must be in [0,1].
        activation (str, optional): activation function applied at the output of RNN.
        bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
            (Intra-Chunk is always bidirectional).
        norm_type (str, optional): Type of normalization to use.

    References:
        [1] Chen, Jingjing, Qirong Mao, and Dong Liu.
        "Dual-Path Transformer Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation."
        arXiv preprint arXiv:2007.13975 (2020).
    """

    def __init__(
        self,
        embed_dim,
        n_heads,
        dim_ff,
        dropout=0.0,
        activation="relu",
        bidirectional=True,
        norm="gLN",
    ):
        super(ImprovedTransformedLayer, self).__init__()

        self.mha = MultiheadAttention(embed_dim, n_heads, dropout=dropout)
        # self.linear_first = nn.Linear(embed_dim, 2 * dim_ff)  # Added by Kay. 20201119
        self.dropout = nn.Dropout(dropout)
        self.recurrent = nn.LSTM(embed_dim, dim_ff, bidirectional=bidirectional, batch_first=True)
        ff_inner_dim = 2 * dim_ff if bidirectional else dim_ff
        self.linear = nn.Linear(ff_inner_dim, embed_dim)
        self.activation = activations.get(activation)()
        self.norm_mha = norms.get(norm)(embed_dim)
        self.norm_ff = norms.get(norm)(embed_dim)

    def forward(self, x):
        tomha = x.permute(2, 0, 1)
        # x is batch, channels, seq_len
        # mha is seq_len, batch, channels
        # self-attention is applied
        out = self.mha(tomha, tomha, tomha)[0]
        x = self.dropout(out.permute(1, 2, 0)) + x
        x = self.norm_mha(x)

        # lstm is applied
        out = self.linear(self.dropout(self.activation(self.recurrent(x.transpose(1, -1))[0])))
        x = self.dropout(out.transpose(1, -1)) + x
        return self.norm_ff(x)

    ''' version 0.3.4
    def forward(self, x):
        x = x.transpose(1, -1)
        # x is batch, seq_len, channels
        # self-attention is applied
        out = self.mha(x, x, x)[0]
        x = self.dropout(out) + x
        x = self.norm_mha(x.transpose(1, -1)).transpose(1, -1)

        # lstm is applied
        out = self.linear(self.dropout(self.activation(self.recurrent(x)[0])))
        # out = self.linear(self.dropout(self.activation(self.linear_first(x)[0])))
        x = self.dropout(out) + x
        return self.norm_ff(x.transpose(1, -1))
    '''


class DPTransformer(nn.Module):
    """Dual-path Transformer introduced in [1].

    Args:
        in_chan (int): Number of input filters.
        n_src (int): Number of masks to estimate.
        n_heads (int): Number of attention heads.
        ff_hid (int): Number of neurons in the RNNs cell state.
            Defaults to 256.
        chunk_size (int): window size of overlap and add processing.
            Defaults to 100.
        hop_size (int or None): hop size (stride) of overlap and add processing.
            Default to `chunk_size // 2` (50% overlap).
        n_repeats (int): Number of repeats. Defaults to 6.
        norm_type (str, optional): Type of normalization to use.
        ff_activation (str, optional): activation function applied at the output of RNN.
        mask_act (str, optional): Which non-linear function to generate mask.
        bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
            (Intra-Chunk is always bidirectional).
        dropout (float, optional): Dropout ratio, must be in [0,1].

    References
        [1] Chen, Jingjing, Qirong Mao, and Dong Liu. "Dual-Path Transformer
        Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation."
        arXiv (2020).
    """

    def __init__(
        self,
        in_chan,
        n_src,
        n_heads=4,
        ff_hid=256,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        ff_activation="relu",
        mask_act="relu",
        bidirectional=True,
        dropout=0,
    ):
        super(DPTransformer, self).__init__()
        self.in_chan = in_chan
        self.n_src = n_src
        self.n_heads = n_heads
        self.ff_hid = ff_hid
        self.chunk_size = chunk_size
        hop_size = hop_size if hop_size is not None else chunk_size // 2
        self.hop_size = hop_size
        self.n_repeats = n_repeats
        self.n_src = n_src
        self.norm_type = norm_type
        self.ff_activation = ff_activation
        self.mask_act = mask_act
        self.bidirectional = bidirectional
        self.dropout = dropout

        # version 0.3.4
        # self.in_norm = norms.get(norm_type)(in_chan)
        self.mha_in_dim = ceil(self.in_chan / self.n_heads) * self.n_heads
        if self.in_chan % self.n_heads != 0:
            warnings.warn(
                f"DPTransformer input dim ({self.in_chan}) is not a multiple of the number of "
                f"heads ({self.n_heads}). Adding extra linear layer at input to accommodate "
                f"(size [{self.in_chan} x {self.mha_in_dim}])"
            )
            self.input_layer = nn.Linear(self.in_chan, self.mha_in_dim)
        else:
            self.input_layer = None

        self.in_norm = norms.get(norm_type)(self.mha_in_dim)
        self.ola = DualPathProcessing(self.chunk_size, self.hop_size)

        # Succession of DPRNNBlocks.
        self.layers = nn.ModuleList([])
        for x in range(self.n_repeats):
            self.layers.append(
                nn.ModuleList(
                    [
                        ImprovedTransformedLayer(
                            self.mha_in_dim,
                            self.n_heads,
                            self.ff_hid,
                            self.dropout,
                            self.ff_activation,
                            True,
                            self.norm_type,
                        ),
                        ImprovedTransformedLayer(
                            self.mha_in_dim,
                            self.n_heads,
                            self.ff_hid,
                            self.dropout,
                            self.ff_activation,
                            self.bidirectional,
                            self.norm_type,
                        ),
                    ]
                )
            )
        net_out_conv = nn.Conv2d(self.mha_in_dim, n_src * self.in_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        self.net_out = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Tanh())
        self.net_gate = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Sigmoid())

        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()

    def forward(self, mixture_w):
        r"""Forward.

        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$

        Returns:
            :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
        """
        if self.input_layer is not None:
            mixture_w = self.input_layer(mixture_w.transpose(1, 2)).transpose(1, 2)
        mixture_w = self.in_norm(mixture_w)  # [batch, bn_chan, n_frames]
        n_orig_frames = mixture_w.shape[-1]

        mixture_w = self.ola.unfold(mixture_w)
        batch, n_filters, self.chunk_size, n_chunks = mixture_w.size()

        for layer_idx in range(len(self.layers)):
            intra, inter = self.layers[layer_idx]
            mixture_w = self.ola.intra_process(mixture_w, intra)
            mixture_w = self.ola.inter_process(mixture_w, inter)

        output = self.first_out(mixture_w)
        output = output.reshape(batch * self.n_src, self.in_chan, self.chunk_size, n_chunks)
        output = self.ola.fold(output, output_size=n_orig_frames)

        output = self.net_out(output) * self.net_gate(output)
        # Compute mask
        output = output.reshape(batch, self.n_src, self.in_chan, -1)
        est_mask = self.output_act(output)
        return est_mask

    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "ff_hid": self.ff_hid,
            "n_heads": self.n_heads,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "ff_activation": self.ff_activation,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "dropout": self.dropout,
        }
        return config


def has_arg(fn, name):
    """Checks if a callable accepts a given keyword argument.

    Args:
        fn (callable): Callable to inspect.
        name (str): Check if `fn` can be called with `name` as a keyword
            argument.

    Returns:
        bool: whether `fn` accepts a `name` keyword argument.
    """
    signature = inspect.signature(fn)
    parameter = signature.parameters.get(name)
    if parameter is None:
        return False
    return parameter.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
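A shape sketch for the masker, using the hyperparameters from get_conf() (random input; not part of the commit):

import torch
from DPTNet_eval.asteroid_test.masknn import DPTransformer

masker = DPTransformer(in_chan=64, n_src=2, ff_hid=256, chunk_size=100,
                       hop_size=50, n_repeats=2, mask_act="sigmoid")
mixture_w = torch.randn(1, 64, 999)   # (batch, n_filters, n_frames) from the encoder
est_mask = masker(mixture_w)          # (1, 2, 64, 999): one sigmoid mask per source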
DPTNet_eval/asteroid_test/masknn/norms.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
5 |
+
|
6 |
+
EPS = 1e-8
|
7 |
+
|
8 |
+
|
9 |
+
class _LayerNorm(nn.Module):
|
10 |
+
"""Layer Normalization base class."""
|
11 |
+
|
12 |
+
def __init__(self, channel_size):
|
13 |
+
super(_LayerNorm, self).__init__()
|
14 |
+
self.channel_size = channel_size
|
15 |
+
self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
|
16 |
+
self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)
|
17 |
+
|
18 |
+
def apply_gain_and_bias(self, normed_x):
|
19 |
+
""" Assumes input of size `[batch, chanel, *]`. """
|
20 |
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)


class GlobLN(_LayerNorm):
    """Global Layer Normalization (globLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
        """
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class ChanLN(_LayerNorm):
    """Channel-wise Layer Normalization (chanLN)."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, *]`

        Returns:
            :class:`torch.Tensor`: chanLN_x `[batch, chan, *]`
        """
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class CumLN(_LayerNorm):
    """Cumulative Global layer normalization (cumLN)."""

    def forward(self, x):
        """

        Args:
            x (:class:`torch.Tensor`): Shape `[batch, channels, length]`
        Returns:
            :class:`torch.Tensor`: cumLN_x `[batch, channels, length]`
        """
        batch, chan, spec_len = x.size()
        cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1)
        cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1)
        cnt = torch.arange(start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype).view(
            1, 1, -1
        )
        cum_mean = cum_sum / cnt
        # Var(x) = E[x^2] - E[x]^2, so the cumulative power sum must also be
        # averaged by cnt before subtracting the squared mean.
        cum_var = cum_pow_sum / cnt - cum_mean.pow(2)
        return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())


class FeatsGlobLN(_LayerNorm):
    """Feature-wise global Layer Normalization (FeatsGlobLN).
    Applies normalization over frames for each channel."""

    def forward(self, x):
        """Applies forward pass.

        Works for any input size > 2D.

        Args:
            x (:class:`torch.Tensor`): `[batch, chan, time]`

        Returns:
            :class:`torch.Tensor`: fgLN_x `[batch, chan, time]`
        """
        stop = len(x.size())
        dims = list(range(2, stop))

        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True, unbiased=False)
        return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())


class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D"""

    def _check_input_dim(self, input):
        if input.dim() < 2 or input.dim() > 4:
            raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim()))


# Aliases.
gLN = GlobLN
fgLN = FeatsGlobLN
cLN = ChanLN
cgLN = CumLN
bN = BatchNorm


def register_norm(custom_norm):
    """Register a custom norm, gettable with `norms.get`.

    Args:
        custom_norm: Custom norm to register.
    """
    if custom_norm.__name__ in globals().keys() or custom_norm.__name__.lower() in globals().keys():
        raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.")
    globals().update({custom_norm.__name__: custom_norm})


def get(identifier):
    """Returns a norm class from a string. Returns its input if it
    is callable (already a :class:`._LayerNorm` for example).

    Args:
        identifier (str or Callable or None): the norm identifier.

    Returns:
        :class:`._LayerNorm` or None
    """
    if identifier is None:
        return None
    elif callable(identifier):
        return identifier
    elif isinstance(identifier, str):
        cls = globals().get(identifier)
        if cls is None:
            raise ValueError("Could not interpret normalization identifier: " + str(identifier))
        return cls
    else:
        raise ValueError("Could not interpret normalization identifier: " + str(identifier))
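The `register_norm` / `get` pair above implements a small registry keyed on this module's globals, so both the class names and the short aliases (`gLN`, `cLN`, ...) resolve. A minimal usage sketch, not part of the upload, assuming the repo root is on the import path and that `_LayerNorm.__init__` takes the channel size (as defined earlier in norms.py); `IdentityNorm` is a hypothetical example:

    from DPTNet_eval.asteroid_test.masknn import norms

    norm_cls = norms.get("gLN")   # alias lookup resolves to GlobLN
    layer = norm_cls(64)          # one learnable gain/bias per channel

    class IdentityNorm(norms._LayerNorm):
        def forward(self, x):
            # No statistics, just the learned gain and bias, for illustration.
            return self.apply_gain_and_bias(x)

    norms.register_norm(IdentityNorm)
    assert norms.get("IdentityNorm") is IdentityNorm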
DPTNet_eval/asteroid_test/models/__init__.py
ADDED
@@ -0,0 +1,59 @@
# Models
# from .conv_tasnet import ConvTasNet
# from .dccrnet import DCCRNet
# from .dcunet import DCUNet
# from .dprnn_tasnet import DPRNNTasNet
# from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet
from .dptnet import DPTNet
# from .lstm_tasnet import LSTMTasNet
# from .demask import DeMask

# Sharing-related
# from .publisher import save_publishable, upload_publishable

__all__ = [
    # "ConvTasNet",
    # "DPRNNTasNet",
    # "SuDORMRFImprovedNet",
    # "SuDORMRFNet",
    "DPTNet",
    # "LSTMTasNet",
    # "DeMask",
    # "DCUNet",
    # "DCCRNet",
    # "save_publishable",
    # "upload_publishable",
]


def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    Args:
        custom_model: Custom model to register.
    """
    if (
        custom_model.__name__ in globals().keys()
        or custom_model.__name__.lower() in globals().keys()
    ):
        raise ValueError(f"Model {custom_model.__name__} already exists. Choose another name.")
    globals().update({custom_model.__name__: custom_model})


def get(identifier):
    """Returns a model class from a string (case-insensitive).

    Args:
        identifier (str): the model name.

    Returns:
        :class:`torch.nn.Module`
    """
    if isinstance(identifier, str):
        to_get = {k.lower(): v for k, v in globals().items()}
        cls = to_get.get(identifier.lower())
        if cls is None:
            raise ValueError(f"Could not interpret model name: {str(identifier)}")
        return cls
    raise ValueError(f"Could not interpret model name: {str(identifier)}")
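Unlike `norms.get`, this lookup is case-insensitive: it lowercases both the registry keys and the identifier before matching. A quick sketch, again assuming the repo root is on the import path:

    from DPTNet_eval.asteroid_test import models

    model_cls = models.get("dptnet")          # resolves to DPTNet despite the lowercase
    assert model_cls is models.get("DPTNet")  # same class either way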
DPTNet_eval/asteroid_test/models/base_models.py
ADDED
@@ -0,0 +1,351 @@
import os
import warnings

import numpy as np
import torch
from torch import nn

from ..masknn import activations
from ..utils.torch_utils import pad_x_to_y


def _unsqueeze_to_3d(x):
    if x.ndim == 1:
        return x.reshape(1, 1, -1)
    elif x.ndim == 2:
        return x.unsqueeze(1)
    else:
        return x


class BaseModel(nn.Module):
    def __init__(self):
        print("initialize BaseModel")
        super().__init__()

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    @torch.no_grad()
    def separate(self, wav, output_dir=None, force_overwrite=False, **kwargs):
        """Infer separated sources from input waveforms.
        Also supports filenames.

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.
            output_dir (str): path to save all the wav files. If None,
                estimated sources will be saved next to the original ones.
            force_overwrite (bool): whether to overwrite existing files.
            **kwargs: keyword arguments to be passed to `_separate`.

        Returns:
            Union[torch.Tensor, numpy.ndarray, None], the estimated sources.
                (batch, n_src, time) or (n_src, time) w/o batch dim.

        .. note::
            By default, `separate` calls `_separate` which calls `forward`.
            For models whose `forward` doesn't return waveform tensors,
            overwrite `_separate` to return waveform tensors.
        """
        if isinstance(wav, str):
            self.file_separate(
                wav, output_dir=output_dir, force_overwrite=force_overwrite, **kwargs
            )
        elif isinstance(wav, np.ndarray):
            print("is ndarray")
            # import pdb ; pdb.set_trace()
            return self.numpy_separate(wav, **kwargs)
        elif isinstance(wav, torch.Tensor):
            print("is torch.Tensor")
            return self.torch_separate(wav, **kwargs)
        else:
            raise ValueError(
                f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
            )

    def torch_separate(self, wav: torch.Tensor, **kwargs) -> torch.Tensor:
        """Core logic of `separate`."""
        # Handle device placement
        input_device = wav.device
        model_device = next(self.parameters()).device
        wav = wav.to(model_device)
        # Forward
        out_wavs = self._separate(wav, **kwargs)

        # FIXME: for now this is the best we can do.
        out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())

        # Back to input device (and numpy if necessary)
        out_wavs = out_wavs.to(input_device)
        return out_wavs

    def numpy_separate(self, wav: np.ndarray, **kwargs) -> np.ndarray:
        """Numpy interface to `separate`."""
        wav = torch.from_numpy(wav)
        out_wav = self.torch_separate(wav, **kwargs)
        out_wav = out_wav.data.numpy()
        return out_wav

    def file_separate(
        self, filename: str, output_dir=None, force_overwrite=False, **kwargs
    ) -> None:
        """Filename interface to `separate`."""
        import soundfile as sf

        wav, fs = sf.read(filename, dtype="float32", always_2d=True)
        # FIXME: support only single-channel files for now.
        to_save = self.numpy_separate(wav[:, 0], **kwargs)

        # Save wav files to filename_est1.wav etc...
        for src_idx, est_src in enumerate(to_save):
            base = ".".join(filename.split(".")[:-1])
            save_name = base + "_est{}.".format(src_idx + 1) + filename.split(".")[-1]
            if os.path.isfile(save_name) and not force_overwrite:
                warnings.warn(
                    f"File {save_name} already exists, pass `force_overwrite=True` to overwrite it",
                    UserWarning,
                )
                return
            if output_dir is not None:
                save_name = os.path.join(output_dir, save_name.split("/")[-1])
            sf.write(save_name, est_src, fs)

    def _separate(self, wav, *args, **kwargs):
        """Hidden separation method

        Args:
            wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor.
                Shape: 1D, 2D or 3D tensor, time last.

        Returns:
            The output of self(wav, *args, **kwargs).
        """
        return self(wav, *args, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate separation model from a model config (file or dict).

        Args:
            pretrained_model_conf_or_path (Union[dict, str]): model conf as
                returned by `serialize`, or path to it. Need to contain
                `model_args` and `state_dict` keys.
            *args: Positional arguments to be passed to the model.
            **kwargs: Keyword arguments to be passed to the model.
                They overwrite the ones in the model package.

        Returns:
            nn.Module corresponding to the pretrained model conf/URL.

        Raises:
            ValueError if the input config file doesn't contain the keys
                `model_name`, `model_args` or `state_dict`.
        """
        from . import get  # Avoid circular imports

        if isinstance(pretrained_model_conf_or_path, str):
            # cached_model = self.cached_download(pretrained_model_conf_or_path)
            if os.path.isfile(pretrained_model_conf_or_path):
                cached_model = pretrained_model_conf_or_path
            else:
                raise ValueError(
                    "Model {} is not a file or doesn't exist.".format(pretrained_model_conf_or_path)
                )

            conf = torch.load(cached_model, map_location="cpu")
        else:
            conf = pretrained_model_conf_or_path

        if "model_name" not in conf.keys():
            raise ValueError(
                "Expected config dictionary to have field "
                "`model_name`. Found only: {}".format(conf.keys())
            )
        if "state_dict" not in conf.keys():
            raise ValueError(
                "Expected config dictionary to have field "
                "`state_dict`. Found only: {}".format(conf.keys())
            )
        if "model_args" not in conf.keys():
            raise ValueError(
                "Expected config dictionary to have field "
                "`model_args`. Found only: {}".format(conf.keys())
            )
        conf["model_args"].update(kwargs)  # kwargs overwrite config.
        # Attempt to find the model and instantiate it.
        try:
            model_class = get(conf["model_name"])
        except ValueError:  # Couldn't get the model, maybe custom.
            model = cls(*args, **conf["model_args"])  # Child class.
        else:
            model = model_class(*args, **conf["model_args"])
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Serialize model and output dictionary.

        Returns:
            dict, serialized model with keys `model_args` and `state_dict`.
        """
        import pytorch_lightning as pl  # Not used in torch.hub

        from .. import __version__ as asteroid_version  # Avoid circular imports

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
            asteroid_version=asteroid_version,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        raise NotImplementedError

    def cached_download(self, filename_or_url):
        if os.path.isfile(filename_or_url):
            print("is file")
            return filename_or_url
        else:
            print("Model {} is not a file or doesn't exist.".format(filename_or_url))


class BaseEncoderMaskerDecoder(BaseModel):
    """Base class for encoder-masker-decoder separation models.

    Args:
        encoder (Encoder): Encoder instance.
        masker (nn.Module): masker network.
        decoder (Decoder): Decoder instance.
        encoder_activation (Optional[str], optional): Activation to apply after encoder.
            See ``asteroid.masknn.activations`` for valid values.
    """

    def __init__(self, encoder, masker, decoder, encoder_activation=None):
        super().__init__()
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder

        self.encoder_activation = encoder_activation
        self.enc_activation = activations.get(encoder_activation or "linear")()

    def forward(self, wav):
        """Enc/Mask/Dec model forward

        Args:
            wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.

        Returns:
            torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
        """
        # Handle 1D, 2D or n-D inputs
        was_one_d = wav.ndim == 1
        # Reshape to (batch, n_mix, time)
        wav = _unsqueeze_to_3d(wav)

        # Real forward
        tf_rep = self.encoder(wav)
        tf_rep = self.postprocess_encoded(tf_rep)
        tf_rep = self.enc_activation(tf_rep)

        est_masks = self.masker(tf_rep)
        est_masks = self.postprocess_masks(est_masks)

        masked_tf_rep = est_masks * tf_rep.unsqueeze(1)
        masked_tf_rep = self.postprocess_masked(masked_tf_rep)

        decoded = self.decoder(masked_tf_rep)
        decoded = self.postprocess_decoded(decoded)

        reconstructed = pad_x_to_y(decoded, wav)
        if was_one_d:
            return reconstructed.squeeze(0)
        else:
            return reconstructed

    def postprocess_encoded(self, tf_rep):
        """Hook to perform transformations on the encoded, time-frequency domain
        representation (output of the encoder) before encoder activation is applied.

        Args:
            tf_rep (Tensor of shape (batch, freq, time)):
                Output of the encoder, before encoder activation is applied.

        Return:
            Transformed `tf_rep`
        """
        return tf_rep

    def postprocess_masks(self, masks):
        """Hook to perform transformations on the masks (output of the masker)
        before masks are applied.

        Args:
            masks (Tensor of shape (batch, n_src, freq, time)):
                Output of the masker

        Return:
            Transformed `masks`
        """
        return masks

    def postprocess_masked(self, masked_tf_rep):
        """Hook to perform transformations on the masked time-frequency domain
        representation (result of masking in the time-frequency domain) before decoding.

        Args:
            masked_tf_rep (Tensor of shape (batch, n_src, freq, time)):
                Masked time-frequency representation, before decoding.

        Return:
            Transformed `masked_tf_rep`
        """
        return masked_tf_rep

    def postprocess_decoded(self, decoded):
        """Hook to perform transformations on the decoded, time domain representation
        (output of the decoder) before original shape reconstruction.

        Args:
            decoded (Tensor of shape (batch, n_src, time)):
                Output of the decoder, before original shape reconstruction.

        Return:
            Transformed `decoded`
        """
        return decoded

    def get_model_args(self):
        """Arguments needed to re-instantiate the model."""
        fb_config = self.encoder.filterbank.get_config()
        masknet_config = self.masker.get_config()
        # Assert both dicts are disjoint
        if not all(k not in fb_config for k in masknet_config):
            raise AssertionError(
                "Filterbank and Mask network config share common keys. "
                "Merging them is not safe."
            )
        # Merge all args under model_args.
        model_args = {
            **fb_config,
            **masknet_config,
            "encoder_activation": self.encoder_activation,
        }
        return model_args


# Backwards compatibility
BaseTasNet = BaseEncoderMaskerDecoder
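`BaseModel.separate` dispatches on the input type: filenames go through `file_separate`, numpy arrays through `numpy_separate`, and tensors through `torch_separate`, which also rescales the estimates so their total energy matches the input (the `FIXME` above). A hedged sketch of the numpy path, with an untrained model and an illustrative 1-second random mixture standing in for real audio:

    import numpy as np
    from DPTNet_eval.asteroid_test.models import DPTNet

    model = DPTNet(n_src=2)   # untrained weights, for shape-checking only
    model.eval()

    mixture = np.random.randn(16000).astype("float32")   # (time,) mono mixture
    est_sources = model.separate(mixture)                # (n_src, time) numpy array

Because the input was 1D, `_unsqueeze_to_3d` adds batch and mic dimensions on the way in and `forward` squeezes the batch dimension back out, so the caller gets `(n_src, time)`.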
DPTNet_eval/asteroid_test/models/dptnet.py
ADDED
@@ -0,0 +1,96 @@
from ..filterbanks import make_enc_dec
from ..masknn import DPTransformer
from .base_models import BaseEncoderMaskerDecoder


class DPTNet(BaseEncoderMaskerDecoder):
    """DPTNet separation model, as described in [1].

    Args:
        n_src (int): Number of masks to estimate.
        out_chan (int or None): Number of bins in the estimated masks.
            Defaults to `in_chan`.
        bn_chan (int): Number of channels after the bottleneck.
            Defaults to 128.
        hid_size (int): Number of neurons in the RNNs cell state.
            Defaults to 128.
        chunk_size (int): window size of overlap and add processing.
            Defaults to 100.
        hop_size (int or None): hop size (stride) of overlap and add processing.
            Defaults to `chunk_size // 2` (50% overlap).
        n_repeats (int): Number of repeats. Defaults to 6.
        norm_type (str, optional): Type of normalization to use. To choose from

            - ``'gLN'``: global Layernorm
            - ``'cLN'``: channelwise Layernorm
        mask_act (str, optional): Which non-linear function to generate mask.
        bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN
            (Intra-Chunk is always bidirectional).
        rnn_type (str, optional): Type of RNN used. Choose between ``'RNN'``,
            ``'LSTM'`` and ``'GRU'``.
        num_layers (int, optional): Number of layers in each RNN.
        dropout (float, optional): Dropout ratio, must be in [0, 1].
        in_chan (int, optional): Number of input channels, should be equal to
            n_filters.
        fb_name (str, className): Filterbank family from which to make encoder
            and decoder. To choose among [``'free'``, ``'analytic_free'``,
            ``'param_sinc'``, ``'stft'``].
        n_filters (int): Number of filters / Input dimension of the masker net.
        kernel_size (int): Length of the filters.
        stride (int, optional): Stride of the convolution.
            If None (default), set to ``kernel_size // 2``.
        **fb_kwargs (dict): Additional kwargs to pass to the filterbank
            creation.

    References:
        [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct
        Context-Aware Modeling for End-to-End Monaural Speech Separation",
        Interspeech 2020.
    """

    def __init__(
        self,
        n_src,
        ff_hid=256,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        ff_activation="relu",
        encoder_activation="relu",
        mask_act="relu",
        bidirectional=True,
        dropout=0,
        in_chan=None,
        fb_name="free",
        kernel_size=16,
        n_filters=64,
        stride=8,
        **fb_kwargs,
    ):
        encoder, decoder = make_enc_dec(
            fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs
        )
        n_feats = encoder.n_feats_out
        if in_chan is not None:
            assert in_chan == n_feats, (
                "Number of filterbank output channels"
                " and number of input channels should "
                "be the same. Received "
                f"{n_feats} and {in_chan}"
            )
        # Update in_chan
        masker = DPTransformer(
            n_feats,
            n_src,
            ff_hid=ff_hid,
            ff_activation=ff_activation,
            chunk_size=chunk_size,
            hop_size=hop_size,
            n_repeats=n_repeats,
            norm_type=norm_type,
            mask_act=mask_act,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
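`DPTNet` itself only wires the pieces together: `make_enc_dec` builds the encoder/decoder pair from the chosen filterbank family, and `DPTransformer` is the masker sandwiched between them. A minimal instantiation sketch, not part of the upload, using the defaults from the signature above (only `n_src` is required; the tensor sizes are illustrative):

    import torch
    from DPTNet_eval.asteroid_test.models import DPTNet

    model = DPTNet(n_src=2, n_repeats=2, chunk_size=100, hop_size=50)
    mixture = torch.randn(1, 1, 16000)    # (batch, mic, time)
    est_sources = model(mixture)          # (batch, n_src, time)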
DPTNet_eval/asteroid_test/utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .torch_utils import tensors_to_device, to_cuda

# The functions above were all in asteroid/utils.py before refactoring into
# asteroid/utils/*_utils.py files. They are imported for backward compatibility.

__all__ = [
    "tensors_to_device",
    "to_cuda",
]
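Since this `__init__.py` only re-exports, the helpers can be reached one level up. A short sketch of `tensors_to_device` walking a mixed container (values are illustrative):

    import torch
    from DPTNet_eval.asteroid_test.utils import tensors_to_device

    batch = {"mix": torch.randn(1, 16000), "ids": [0, 1]}
    batch = tensors_to_device(batch, torch.device("cpu"))  # tensors moved, non-tensors untouched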
DPTNet_eval/asteroid_test/utils/torch_utils.py
ADDED
@@ -0,0 +1,126 @@
import torch
from torch import nn
from collections import OrderedDict


def to_cuda(tensors):  # pragma: no cover (No CUDA on travis)
    """Transfer tensor, dict or list of tensors to GPU.

    Args:
        tensors (:class:`torch.Tensor`, list or dict): May be a single, a
            list or a dictionary of tensors.

    Returns:
        :class:`torch.Tensor`:
            Same as input but transferred to cuda. Goes through lists and dicts
            and transfers the torch.Tensor to cuda. Leaves the rest untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.cuda()
    if isinstance(tensors, list):
        return [to_cuda(tens) for tens in tensors]
    if isinstance(tensors, dict):
        for key in tensors.keys():
            tensors[key] = to_cuda(tensors[key])
        return tensors
    raise TypeError(
        "tensors must be a tensor or a list or dict of tensors. "
        "Got tensors of type {}".format(type(tensors))
    )


def tensors_to_device(tensors, device):
    """Transfer tensor, dict or list of tensors to device.

    Args:
        tensors (:class:`torch.Tensor`): May be a single, a list or a
            dictionary of tensors.
        device (:class:`torch.device`): the device where to place the tensors.

    Returns:
        Union[:class:`torch.Tensor`, list, tuple, dict]:
            Same as input but transferred to device.
            Goes through lists and dicts and transfers the torch.Tensor to
            device. Leaves the rest untouched.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    elif isinstance(tensors, (list, tuple)):
        return [tensors_to_device(tens, device) for tens in tensors]
    elif isinstance(tensors, dict):
        for key in tensors.keys():
            tensors[key] = tensors_to_device(tensors[key], device)
        return tensors
    else:
        return tensors


def pad_x_to_y(x, y, axis=-1):
    """Pad first argument to have same size as second argument.

    Args:
        x (torch.Tensor): Tensor to be padded.
        y (torch.Tensor): Tensor to pad x to.
        axis (int): Axis to pad on.

    Returns:
        torch.Tensor, x padded to match y's shape.
    """
    if axis != -1:
        raise NotImplementedError
    inp_len = y.size(axis)
    output_len = x.size(axis)
    return nn.functional.pad(x, [0, inp_len - output_len])


def load_state_dict_in(state_dict, model):
    """Strictly loads state_dict in model, or the next submodel.
    Useful to load standalone model after training it with System.

    Args:
        state_dict (OrderedDict): the state_dict to load.
        model (torch.nn.Module): the model to load it into.

    Returns:
        torch.nn.Module: model with loaded weights.

    .. note:: Keys in a state_dict look like ``object1.object2.layer_name.weight.etc``.
        We first try to load the model in the classic way.
        If this fails, we remove the first left part of the key to obtain
        ``object2.layer_name.weight.etc``.
        Blindly loading with ``strict=False`` should be done with some logging
        of the missing keys in the state_dict and the model.
    """
    try:
        # This can fail if the model was included into a bigger nn.Module
        # object. For example, into System.
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        # Keys look like object1.object2.layer_name.weight.etc
        # The following removes the first left part of each key to obtain
        # object2.layer_name.weight.etc, then retries a strict load.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_k = k[k.find(".") + 1 :]
            new_state_dict[new_k] = v
        model.load_state_dict(new_state_dict, strict=True)
    return model


def are_models_equal(model1, model2):
    """Check for weights equality between models.

    Args:
        model1 (nn.Module): model instance to be compared.
        model2 (nn.Module): second model instance to be compared.

    Returns:
        bool: Whether all model weights are equal.
    """
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        if p1.data.ne(p2.data).sum() > 0:
            return False
    return True
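`pad_x_to_y` is what `BaseEncoderMaskerDecoder.forward` uses to restore the decoder output to the input length; since `nn.functional.pad` also accepts a negative pad width (which truncates), it works whether the decoder output undershoots or overshoots. A final sketch, not part of the upload, covering it and `are_models_equal` with illustrative shapes:

    import torch
    from DPTNet_eval.asteroid_test.utils.torch_utils import pad_x_to_y, are_models_equal

    decoded = torch.randn(1, 2, 15992)   # decoder output, 8 samples short
    mixture = torch.randn(1, 1, 16000)
    padded = pad_x_to_y(decoded, mixture)
    assert padded.shape[-1] == 16000     # right-padded with zeros to match

    m1, m2 = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
    assert not are_models_equal(m1, m2)  # independently initialized weights differ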