HoneyTian committed
Commit 1d4c9c3 · 1 Parent(s): 6cd307e

add frcrn model

examples/frcrn/step_2_train_model.py CHANGED
@@ -1,5 +1,13 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
  import json
5
  import logging
@@ -163,7 +171,7 @@ def main():
163
  model.train()
164
 
165
  # optimizer
166
- logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
167
  optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)
168
 
169
  # resume training
@@ -217,8 +225,7 @@ def main():
217
  average_pesq_score = 1000000000
218
  average_loss = 1000000000
219
  average_neg_si_snr_loss = 1000000000
220
- average_mag_loss = 1000000000
221
- average_pha_loss = 1000000000
222
 
223
  model_list = list()
224
  best_epoch_idx = None
@@ -236,8 +243,7 @@ def main():
236
  total_pesq_score = 0.
237
  total_loss = 0.
238
  total_neg_si_snr_loss = 0.
239
- total_map_loss = 0.
240
- total_pha_loss = 0.
241
  total_batches = 0.
242
 
243
  progress_bar_train = tqdm(
@@ -253,9 +259,9 @@ def main():
253
  denoise_audios = est_wav
254
 
255
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
256
- map_loss, pha_loss = model.mag_pha_loss_fn(est_mask, clean_audios, noisy_audios)
257
 
258
- loss = 0.5 * map_loss + 0.5 * pha_loss + 0.5 * neg_si_snr_loss
259
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
260
  logger.info(f"find nan or inf in loss.")
261
  continue
@@ -273,15 +279,13 @@ def main():
273
  total_pesq_score += pesq_score
274
  total_loss += loss.item()
275
  total_neg_si_snr_loss += neg_si_snr_loss.item()
276
- total_map_loss += map_loss.item()
277
- total_pha_loss += pha_loss.item()
278
  total_batches += 1
279
 
280
  average_pesq_score = round(total_pesq_score / total_batches, 4)
281
  average_loss = round(total_loss / total_batches, 4)
282
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
283
- average_mag_loss = round(total_map_loss / total_batches, 4)
284
- average_pha_loss = round(total_pha_loss / total_batches, 4)
285
 
286
  progress_bar_train.update(1)
287
  progress_bar_train.set_postfix({
@@ -289,8 +293,7 @@ def main():
289
  "pesq_score": average_pesq_score,
290
  "loss": average_loss,
291
  "neg_si_snr_loss": average_neg_si_snr_loss,
292
- "mag_loss": average_mag_loss,
293
- "pha_loss": average_pha_loss,
294
  })
295
 
296
  # evaluation
@@ -302,8 +305,7 @@ def main():
302
  total_pesq_score = 0.
303
  total_loss = 0.
304
  total_neg_si_snr_loss = 0.
305
- total_map_loss = 0.
306
- total_pha_loss = 0.
307
  total_batches = 0.
308
 
309
  progress_bar_train.close()
@@ -319,9 +321,9 @@ def main():
319
  denoise_audios = est_wav
320
 
321
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
322
- map_loss, pha_loss = model.mag_pha_loss_fn(est_mask, clean_audios, noisy_audios)
323
 
324
- loss = 0.5 * map_loss + 0.5 * pha_loss + 0.5 * neg_si_snr_loss
325
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
326
  logger.info(f"find nan or inf in loss.")
327
  continue
@@ -333,15 +335,13 @@ def main():
333
  total_pesq_score += pesq_score
334
  total_loss += loss.item()
335
  total_neg_si_snr_loss += neg_si_snr_loss.item()
336
- total_map_loss += map_loss.item()
337
- total_pha_loss += pha_loss.item()
338
  total_batches += 1
339
 
340
  average_pesq_score = round(total_pesq_score / total_batches, 4)
341
  average_loss = round(total_loss / total_batches, 4)
342
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
343
- average_mag_loss = round(total_map_loss / total_batches, 4)
344
- average_pha_loss = round(total_pha_loss / total_batches, 4)
345
 
346
  progress_bar_eval.update(1)
347
  progress_bar_eval.set_postfix({
@@ -349,15 +349,13 @@ def main():
349
  "pesq_score": average_pesq_score,
350
  "loss": average_loss,
351
  "neg_si_snr_loss": average_neg_si_snr_loss,
352
- "mag_loss": average_mag_loss,
353
- "pha_loss": average_pha_loss,
354
  })
355
 
356
  total_pesq_score = 0.
357
  total_loss = 0.
358
  total_neg_si_snr_loss = 0.
359
- total_map_loss = 0.
360
- total_pha_loss = 0.
361
  total_batches = 0.
362
 
363
  progress_bar_eval.close()
@@ -402,8 +400,7 @@ def main():
402
  "pesq_score": average_pesq_score,
403
  "loss": average_loss,
404
  "neg_si_snr_loss": average_neg_si_snr_loss,
405
- "mag_loss": average_mag_loss,
406
- "pha_loss": average_pha_loss,
407
  }
408
  metrics_filename = save_dir / "metrics_epoch.json"
409
  with open(metrics_filename, "w", encoding="utf-8") as f:
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
+ """
4
+ In the FRCRN paper:
5
+ the model is trained for 120 epochs on the WSJ0 dataset, reaching pesq 3.62, stoi 98.24, si-snr 21.33.
6
+
7
+ WSJ0 contains about 80 hours of clean English speech recordings.
8
+
9
+ My audio is about 1300 hours, so roughly 10 epochs should be needed.
10
+ """
11
  import argparse
12
  import json
13
  import logging
 
171
  model.train()
172
 
173
  # optimizer
174
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
175
  optimizer = torch.optim.AdamW(model.get_params(weight_decay=config.weight_decay), config.lr)
176
 
177
  # resume training
 
225
  average_pesq_score = 1000000000
226
  average_loss = 1000000000
227
  average_neg_si_snr_loss = 1000000000
228
+ average_mask_loss = 1000000000
 
229
 
230
  model_list = list()
231
  best_epoch_idx = None
 
243
  total_pesq_score = 0.
244
  total_loss = 0.
245
  total_neg_si_snr_loss = 0.
246
+ total_mask_loss = 0.
 
247
  total_batches = 0.
248
 
249
  progress_bar_train = tqdm(
 
259
  denoise_audios = est_wav
260
 
261
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
262
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
263
 
264
+ loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
265
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
266
  logger.info(f"find nan or inf in loss.")
267
  continue
 
279
  total_pesq_score += pesq_score
280
  total_loss += loss.item()
281
  total_neg_si_snr_loss += neg_si_snr_loss.item()
282
+ total_mask_loss += mask_loss.item()
 
283
  total_batches += 1
284
 
285
  average_pesq_score = round(total_pesq_score / total_batches, 4)
286
  average_loss = round(total_loss / total_batches, 4)
287
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
288
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
 
289
 
290
  progress_bar_train.update(1)
291
  progress_bar_train.set_postfix({
 
293
  "pesq_score": average_pesq_score,
294
  "loss": average_loss,
295
  "neg_si_snr_loss": average_neg_si_snr_loss,
296
+ "mask_loss": average_mask_loss,
 
297
  })
298
 
299
  # evaluation
 
305
  total_pesq_score = 0.
306
  total_loss = 0.
307
  total_neg_si_snr_loss = 0.
308
+ total_mask_loss = 0.
 
309
  total_batches = 0.
310
 
311
  progress_bar_train.close()
 
321
  denoise_audios = est_wav
322
 
323
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
324
+ mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
325
 
326
+ loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
327
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
328
  logger.info(f"find nan or inf in loss.")
329
  continue
 
335
  total_pesq_score += pesq_score
336
  total_loss += loss.item()
337
  total_neg_si_snr_loss += neg_si_snr_loss.item()
338
+ total_mask_loss += mask_loss.item()
 
339
  total_batches += 1
340
 
341
  average_pesq_score = round(total_pesq_score / total_batches, 4)
342
  average_loss = round(total_loss / total_batches, 4)
343
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
344
+ average_mask_loss = round(total_mask_loss / total_batches, 4)
 
345
 
346
  progress_bar_eval.update(1)
347
  progress_bar_eval.set_postfix({
 
349
  "pesq_score": average_pesq_score,
350
  "loss": average_loss,
351
  "neg_si_snr_loss": average_neg_si_snr_loss,
352
+ "mask_loss": average_mask_loss,
 
353
  })
354
 
355
  total_pesq_score = 0.
356
  total_loss = 0.
357
  total_neg_si_snr_loss = 0.
358
+ total_mask_loss = 0.
 
359
  total_batches = 0.
360
 
361
  progress_bar_eval.close()
 
400
  "pesq_score": average_pesq_score,
401
  "loss": average_loss,
402
  "neg_si_snr_loss": average_neg_si_snr_loss,
403
+ "mask_loss": average_mask_loss,
 
404
  }
405
  metrics_filename = save_dir / "metrics_epoch.json"
406
  with open(metrics_filename, "w", encoding="utf-8") as f:
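Note on the updated objective above: the combined loss is now loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss, replacing the earlier 0.5 * mag + 0.5 * pha + 0.5 * neg_si_snr combination, and the new module docstring estimates roughly 10 epochs for ~1300 hours of audio by scaling from the FRCRN paper's 120 epochs on ~80 hours of WSJ0. A back-of-the-envelope sketch of that scaling; the figures come from the docstring, and the "same total hours of audio seen" rule is an assumption, not something the paper states:

    # Rough epoch estimate, assuming the model should see roughly the same total
    # amount of audio as in the paper's setup (an assumption, not a stated rule).
    paper_epochs = 120    # epochs used in the FRCRN paper (per the docstring)
    paper_hours = 80      # approximate hours of clean speech in WSJ0
    my_hours = 1300       # approximate hours in the local training set

    total_hours_seen = paper_epochs * paper_hours   # 9600 hours of audio processed
    estimated_epochs = total_hours_seen / my_hours  # ~7.4
    print(round(estimated_epochs, 1))               # 7.4, i.e. on the order of 10 epochs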
examples/frcrn/yaml/config.yaml CHANGED
@@ -30,4 +30,4 @@ max_snr_db: 20
30
 
31
  num_workers: 8
32
  batch_size: 32
33
- eval_steps: 25000
 
30
 
31
  num_workers: 8
32
  batch_size: 32
33
+ eval_steps: 10000
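eval_steps drops from 25000 to 10000, so evaluation and checkpoint selection run 2.5x as often. A minimal sketch of how a step-based eval_steps setting is usually consumed; the real loop lives in step_2_train_model.py and may be wired differently:

    eval_steps = 10000   # value from examples/frcrn/yaml/config.yaml

    for step in range(1, 50001):
        # ... one optimizer step on a training batch would run here ...
        if step % eval_steps == 0:
            # evaluation + checkpointing would run here: 5 times per 50k steps now,
            # versus 2 times with the previous value of 25000.
            pass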
toolbox/torchaudio/models/dfnet/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/dfnet/configuration_dfnet.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from typing import Tuple
4
+
5
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
+
7
+
8
+ class DfNetConfig(PretrainedConfig):
9
+ def __init__(self,
10
+ sample_rate: int = 8000,
11
+ nfft: int = 512,
12
+ win_size: int = 200,
13
+ hop_size: int = 80,
14
+ win_type: str = "hann",
15
+
16
+ spec_bins: int = 256,
17
+
18
+ conv_channels: int = 64,
19
+ conv_kernel_size_input: Tuple[int, int] = (3, 3),
20
+ conv_kernel_size_inner: Tuple[int, int] = (1, 3),
21
+ conv_lookahead: int = 0,
22
+
23
+ convt_kernel_size_inner: Tuple[int, int] = (1, 3),
24
+
25
+ embedding_hidden_size: int = 256,
26
+ encoder_combine_op: str = "concat",
27
+
28
+ encoder_emb_skip_op: str = "none",
29
+ encoder_emb_linear_groups: int = 16,
30
+ encoder_emb_hidden_size: int = 256,
31
+
32
+ encoder_linear_groups: int = 32,
33
+
34
+ lsnr_max: int = 30,
35
+ lsnr_min: int = -15,
36
+ norm_tau: float = 1.,
37
+
38
+ decoder_emb_num_layers: int = 3,
39
+ decoder_emb_skip_op: str = "none",
40
+ decoder_emb_linear_groups: int = 16,
41
+ decoder_emb_hidden_size: int = 256,
42
+
43
+ df_decoder_hidden_size: int = 256,
44
+ df_num_layers: int = 2,
45
+ df_order: int = 5,
46
+ df_bins: int = 96,
47
+ df_gru_skip: str = "grouped_linear",
48
+ df_decoder_linear_groups: int = 16,
49
+ df_pathway_kernel_size_t: int = 5,
50
+ df_lookahead: int = 2,
51
+
52
+ use_post_filter: bool = False,
53
+ **kwargs
54
+ ):
55
+ super(DfNetConfig, self).__init__(**kwargs)
56
+ # transform
57
+ self.sample_rate = sample_rate
58
+ self.nfft = nfft
59
+ self.win_size = win_size
60
+ self.hop_size = hop_size
61
+ self.win_type = win_type
62
+
63
+ # spectrum
64
+ self.spec_bins = spec_bins
65
+
66
+ # conv
67
+ self.conv_channels = conv_channels
68
+ self.conv_kernel_size_input = conv_kernel_size_input
69
+ self.conv_kernel_size_inner = conv_kernel_size_inner
70
+ self.conv_lookahead = conv_lookahead
71
+
72
+ self.convt_kernel_size_inner = convt_kernel_size_inner
73
+
74
+ self.embedding_hidden_size = embedding_hidden_size
75
+
76
+ # encoder
77
+ self.encoder_emb_skip_op = encoder_emb_skip_op
78
+ self.encoder_emb_linear_groups = encoder_emb_linear_groups
79
+ self.encoder_emb_hidden_size = encoder_emb_hidden_size
80
+
81
+ self.encoder_linear_groups = encoder_linear_groups
82
+ self.encoder_combine_op = encoder_combine_op
83
+
84
+ self.lsnr_max = lsnr_max
85
+ self.lsnr_min = lsnr_min
86
+ self.norm_tau = norm_tau
87
+
88
+ # decoder
89
+ self.decoder_emb_num_layers = decoder_emb_num_layers
90
+ self.decoder_emb_skip_op = decoder_emb_skip_op
91
+ self.decoder_emb_linear_groups = decoder_emb_linear_groups
92
+ self.decoder_emb_hidden_size = decoder_emb_hidden_size
93
+
94
+ # df decoder
95
+ self.df_decoder_hidden_size = df_decoder_hidden_size
96
+ self.df_num_layers = df_num_layers
97
+ self.df_order = df_order
98
+ self.df_bins = df_bins
99
+ self.df_gru_skip = df_gru_skip
100
+ self.df_decoder_linear_groups = df_decoder_linear_groups
101
+ self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
102
+ self.df_lookahead = df_lookahead
103
+
104
+ # runtime
105
+ self.use_post_filter = use_post_filter
106
+
107
+
108
+ if __name__ == "__main__":
109
+ pass
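DfNetConfig above is a plain hyperparameter container grouped by sub-module (STFT, conv encoder, GRU embedding, deep-filter decoder, runtime). A small usage sketch, assuming PretrainedConfig exposes the to_yaml_file helper that DfNetPretrainedModel.save_pretrained relies on later in this commit:

    from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig

    # Defaults target 8 kHz audio with a 512-point FFT, 200-sample window, 80-sample hop.
    config = DfNetConfig(sample_rate=8000, conv_channels=64, df_order=5)
    print(config.spec_bins, config.df_bins)   # 256 96 with the defaults
    config.to_yaml_file("config.yaml")        # serialize next to model.pt, as save_pretrained() does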
toolbox/torchaudio/models/dfnet/conv_stft.py ADDED
@@ -0,0 +1,147 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
5
+ """
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from scipy.signal import get_window
11
+
12
+
13
+ def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
14
+ if win_type == "None" or win_type is None:
15
+ window = np.ones(win_size)
16
+ else:
17
+ window = get_window(win_type, win_size, fftbins=True)**0.5
18
+
19
+ fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
20
+ real_kernel = np.real(fourier_basis)
21
+ image_kernel = np.imag(fourier_basis)
22
+ kernel = np.concatenate([real_kernel, image_kernel], 1).T
23
+
24
+ if inverse:
25
+ kernel = np.linalg.pinv(kernel).T
26
+
27
+ kernel = kernel * window
28
+ kernel = kernel[:, None, :]
29
+ result = (
30
+ torch.from_numpy(kernel.astype(np.float32)),
31
+ torch.from_numpy(window[None, :, None].astype(np.float32))
32
+ )
33
+ return result
34
+
35
+
36
+ class ConvSTFT(nn.Module):
37
+
38
+ def __init__(self,
39
+ nfft: int = None,
40
+ win_size: int,
41
+ hop_size: int,
42
+ win_type: str = "hamming",
43
+ feature_type: str = "real",
44
+ requires_grad: bool = False):
45
+ super(ConvSTFT, self).__init__()
46
+
47
+ if nfft is None:
48
+ self.nfft = int(2**np.ceil(np.log2(win_size)))
49
+ else:
50
+ self.nfft = nfft
51
+
52
+ kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
53
+ self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
54
+
55
+ self.win_size = win_size
56
+ self.hop_size = hop_size
57
+
58
+ self.stride = hop_size
59
+ self.dim = self.nfft
60
+ self.feature_type = feature_type
61
+
62
+ def forward(self, inputs: torch.Tensor):
63
+ if inputs.dim() == 2:
64
+ inputs = torch.unsqueeze(inputs, 1)
65
+
66
+ outputs = F.conv1d(inputs, self.weight, stride=self.stride)
67
+
68
+ if self.feature_type == "complex":
69
+ return outputs
70
+ else:
71
+ dim = self.dim // 2 + 1
72
+ real = outputs[:, :dim, :]
73
+ imag = outputs[:, dim:, :]
74
+ mags = torch.sqrt(real**2 + imag**2)
75
+ phase = torch.atan2(imag, real)
76
+ return mags, phase
77
+
78
+
79
+ class ConviSTFT(nn.Module):
80
+
81
+ def __init__(self,
82
+ win_size: int,
83
+ hop_size: int,
84
+ nfft: int = None,
85
+ win_type: str = "hamming",
86
+ feature_type: str = "real",
87
+ requires_grad: bool = False):
88
+ super(ConviSTFT, self).__init__()
89
+ if nfft is None:
90
+ self.nfft = int(2**np.ceil(np.log2(win_size)))
91
+ else:
92
+ self.nfft = nfft
93
+
94
+ kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
95
+ self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
96
+
97
+ self.win_size = win_size
98
+ self.hop_size = hop_size
99
+ self.win_type = win_type
100
+
101
+ self.stride = hop_size
102
+ self.dim = self.nfft
103
+ self.feature_type = feature_type
104
+
105
+ self.register_buffer("window", window)
106
+ self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
107
+
108
+ def forward(self,
109
+ inputs: torch.Tensor,
110
+ phase: torch.Tensor = None):
111
+ """
112
+ :param inputs: torch.Tensor, shape: [b, n+2, t] (complex spec) or [b, n//2+1, t] (mags)
113
+ :param phase: torch.Tensor, shape: [b, n//2+1, t]
114
+ :return:
115
+ """
116
+ if phase is not None:
117
+ real = inputs * torch.cos(phase)
118
+ imag = inputs * torch.sin(phase)
119
+ inputs = torch.cat([real, imag], 1)
120
+ outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
121
+
122
+ # this is from torch-stft: https://github.com/pseeth/torch-stft
123
+ t = self.window.repeat(1, 1, inputs.size(-1))**2
124
+ coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
125
+ outputs = outputs / (coff + 1e-8)
126
+ return outputs
127
+
128
+
129
+ def main():
130
+ stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
131
+ istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")
132
+
133
+ mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
134
+
135
+ spec = stft.forward(mixture)
136
+ # shape: [batch_size, freq_bins, time_steps]
137
+ print(spec.shape)
138
+
139
+ waveform = istft.forward(spec)
140
+ # shape: [batch_size, channels, num_samples]
141
+ print(waveform.shape)
142
+
143
+ return
144
+
145
+
146
+ if __name__ == "__main__":
147
+ main()
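A quick shape walk-through for the main() demo above; this is only the framing arithmetic implied by the conv1d / conv_transpose1d calls, not extra functionality:

    nfft, win_size, hop_size, n_samples = 512, 512, 200, 8000 * 40

    freq_rows = 2 * (nfft // 2 + 1)                    # 514 rows: real part stacked on imaginary part
    n_frames = (n_samples - win_size) // hop_size + 1  # 1598 frames
    print(freq_rows, n_frames)                         # spec shape: [1, 514, 1598]

    recon_len = (n_frames - 1) * hop_size + win_size   # 319912 samples
    print(recon_len)                                   # the iSTFT output is slightly shorter than the
                                                       # 320000-sample input: the last partial window
                                                       # is never analyzed.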
toolbox/torchaudio/models/dfnet/modeling_dfnet.py ADDED
@@ -0,0 +1,952 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ import math
5
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchaudio
11
+
12
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
13
+ from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
14
+ from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT
15
+
16
+
17
+ MODEL_FILE = "model.pt"
18
+
19
+
20
+ norm_layer_dict = {
21
+ "batch_norm_2d": torch.nn.BatchNorm2d
22
+ }
23
+
24
+
25
+ activation_layer_dict = {
26
+ "relu": torch.nn.ReLU,
27
+ "identity": torch.nn.Identity,
28
+ "sigmoid": torch.nn.Sigmoid,
29
+ }
30
+
31
+
32
+ class CausalConv2d(nn.Sequential):
33
+ def __init__(self,
34
+ in_channels: int,
35
+ out_channels: int,
36
+ kernel_size: Union[int, Iterable[int]],
37
+ fstride: int = 1,
38
+ dilation: int = 1,
39
+ fpad: bool = True,
40
+ bias: bool = True,
41
+ separable: bool = False,
42
+ norm_layer: str = "batch_norm_2d",
43
+ activation_layer: str = "relu",
44
+ lookahead: int = 0
45
+ ):
46
+ """
47
+ Causal Conv2d by delaying the signal for any lookahead.
48
+
49
+ Expected input format: [batch_size, channels, time_steps, spec_dim]
50
+
51
+ :param in_channels:
52
+ :param out_channels:
53
+ :param kernel_size:
54
+ :param fstride:
55
+ :param dilation:
56
+ :param fpad:
57
+ """
58
+ super(CausalConv2d, self).__init__()
59
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
60
+
61
+ if fpad:
62
+ fpad_ = kernel_size[1] // 2 + dilation - 1
63
+ else:
64
+ fpad_ = 0
65
+
66
+ # for last 2 dim, pad (left, right, top, bottom).
67
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
68
+
69
+ layers = list()
70
+ if any(x > 0 for x in pad):
71
+ layers.append(nn.ConstantPad2d(pad, 0.0))
72
+
73
+ groups = math.gcd(in_channels, out_channels) if separable else 1
74
+ if groups == 1:
75
+ separable = False
76
+ if max(kernel_size) == 1:
77
+ separable = False
78
+
79
+ layers.append(
80
+ nn.Conv2d(
81
+ in_channels,
82
+ out_channels,
83
+ kernel_size=kernel_size,
84
+ padding=(0, fpad_),
85
+ stride=(1, fstride), # stride over time is always 1
86
+ dilation=(1, dilation), # dilation over time is always 1
87
+ groups=groups,
88
+ bias=bias,
89
+ )
90
+ )
91
+
92
+ if separable:
93
+ layers.append(
94
+ nn.Conv2d(
95
+ out_channels,
96
+ out_channels,
97
+ kernel_size=1,
98
+ bias=False,
99
+ )
100
+ )
101
+
102
+ if norm_layer is not None:
103
+ norm_layer = norm_layer_dict[norm_layer]
104
+ layers.append(norm_layer(out_channels))
105
+
106
+ if activation_layer is not None:
107
+ activation_layer = activation_layer_dict[activation_layer]
108
+ layers.append(activation_layer())
109
+
110
+ super().__init__(*layers)
111
+
112
+ def forward(self, inputs):
113
+ for module in self:
114
+ inputs = module(inputs)
115
+ return inputs
116
+
117
+
118
+ class CausalConvTranspose2d(nn.Sequential):
119
+ def __init__(self,
120
+ in_channels: int,
121
+ out_channels: int,
122
+ kernel_size: Union[int, Iterable[int]],
123
+ fstride: int = 1,
124
+ dilation: int = 1,
125
+ fpad: bool = True,
126
+ bias: bool = True,
127
+ separable: bool = False,
128
+ norm_layer: str = "batch_norm_2d",
129
+ activation_layer: str = "relu",
130
+ lookahead: int = 0
131
+ ):
132
+ """
133
+ Causal ConvTranspose2d.
134
+
135
+ Expected input format: [batch_size, channels, time_steps, spec_dim]
136
+ """
137
+ super(CausalConvTranspose2d, self).__init__()
138
+
139
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
140
+
141
+ if fpad:
142
+ fpad_ = kernel_size[1] // 2
143
+ else:
144
+ fpad_ = 0
145
+
146
+ # for last 2 dim, pad (left, right, top, bottom).
147
+ pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
148
+
149
+ layers = []
150
+ if any(x > 0 for x in pad):
151
+ layers.append(nn.ConstantPad2d(pad, 0.0))
152
+
153
+ groups = math.gcd(in_channels, out_channels) if separable else 1
154
+ if groups == 1:
155
+ separable = False
156
+
157
+ layers.append(
158
+ nn.ConvTranspose2d(
159
+ in_channels,
160
+ out_channels,
161
+ kernel_size=kernel_size,
162
+ padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
163
+ output_padding=(0, fpad_),
164
+ stride=(1, fstride), # stride over time is always 1
165
+ dilation=(1, dilation), # dilation over time is always 1
166
+ groups=groups,
167
+ bias=bias,
168
+ )
169
+ )
170
+
171
+ if separable:
172
+ layers.append(
173
+ nn.Conv2d(
174
+ out_channels,
175
+ out_channels,
176
+ kernel_size=1,
177
+ bias=False,
178
+ )
179
+ )
180
+
181
+ if norm_layer is not None:
182
+ norm_layer = norm_layer_dict[norm_layer]
183
+ layers.append(norm_layer(out_channels))
184
+
185
+ if activation_layer is not None:
186
+ activation_layer = activation_layer_dict[activation_layer]
187
+ layers.append(activation_layer())
188
+
189
+ super().__init__(*layers)
190
+
191
+
192
+ class GroupedLinear(nn.Module):
193
+
194
+ def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
195
+ super().__init__()
196
+ # self.weight: Tensor
197
+ self.input_size = input_size
198
+ self.hidden_size = hidden_size
199
+ self.groups = groups
200
+ assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
201
+ assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
202
+ self.ws = input_size // groups
203
+ self.register_parameter(
204
+ "weight",
205
+ torch.nn.Parameter(
206
+ torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
207
+ ),
208
+ )
209
+ self.reset_parameters()
210
+
211
+ def reset_parameters(self):
212
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
213
+
214
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
215
+ # x: [..., I]
216
+ b, t, _ = x.shape
217
+ # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
218
+ new_shape = (b, t, self.groups, self.ws)
219
+ x = x.view(new_shape)
220
+ # The better way, but not supported by torchscript
221
+ # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
222
+ x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
223
+ x = x.flatten(2, 3) # [B, T, H]
224
+ return x
225
+
226
+ def __repr__(self):
227
+ cls = self.__class__.__name__
228
+ return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
229
+
230
+
231
+ class SqueezedGRU_S(nn.Module):
232
+ """
233
+ SGE net: Video object detection with squeezed GRU and information entropy map
234
+ https://arxiv.org/abs/2106.07224
235
+ """
236
+
237
+ def __init__(
238
+ self,
239
+ input_size: int,
240
+ hidden_size: int,
241
+ output_size: Optional[int] = None,
242
+ num_layers: int = 1,
243
+ linear_groups: int = 8,
244
+ batch_first: bool = True,
245
+ skip_op: str = "none",
246
+ activation_layer: str = "identity",
247
+ ):
248
+ super().__init__()
249
+ self.input_size = input_size
250
+ self.hidden_size = hidden_size
251
+
252
+ self.linear_in = nn.Sequential(
253
+ GroupedLinear(
254
+ input_size=input_size,
255
+ hidden_size=hidden_size,
256
+ groups=linear_groups,
257
+ ),
258
+ activation_layer_dict[activation_layer](),
259
+ )
260
+
261
+ # gru skip operator
262
+ self.gru_skip_op = None
263
+
264
+ if skip_op == "none":
265
+ self.gru_skip_op = None
266
+ elif skip_op == "identity":
267
+ if input_size != output_size:
268
+ raise AssertionError("Dimensions do not match")
269
+ self.gru_skip_op = nn.Identity()
270
+ elif skip_op == "grouped_linear":
271
+ self.gru_skip_op = GroupedLinear(
272
+ input_size=hidden_size,
273
+ hidden_size=hidden_size,
274
+ groups=linear_groups,
275
+ )
276
+ else:
277
+ raise NotImplementedError()
278
+
279
+ self.gru = nn.GRU(
280
+ input_size=hidden_size,
281
+ hidden_size=hidden_size,
282
+ num_layers=num_layers,
283
+ batch_first=batch_first,
284
+ bidirectional=False,
285
+ )
286
+
287
+ if output_size is not None:
288
+ self.linear_out = nn.Sequential(
289
+ GroupedLinear(
290
+ input_size=hidden_size,
291
+ hidden_size=output_size,
292
+ groups=linear_groups,
293
+ ),
294
+ activation_layer_dict[activation_layer](),
295
+ )
296
+ else:
297
+ self.linear_out = nn.Identity()
298
+
299
+ def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
300
+ x = self.linear_in(inputs)
301
+
302
+ x, h = self.gru.forward(x, h)
303
+
304
+ x = self.linear_out(x)
305
+
306
+ if self.gru_skip_op is not None:
307
+ x = x + self.gru_skip_op(inputs)
308
+
309
+ return x, h
310
+
311
+
312
+ class Add(nn.Module):
313
+ def forward(self, a, b):
314
+ return a + b
315
+
316
+
317
+ class Concat(nn.Module):
318
+ def forward(self, a, b):
319
+ return torch.cat((a, b), dim=-1)
320
+
321
+
322
+ class Encoder(nn.Module):
323
+ def __init__(self, config: DfNetConfig):
324
+ super(Encoder, self).__init__()
325
+ self.embedding_input_size = config.conv_channels * config.spec_bins // 4
326
+ self.embedding_output_size = config.conv_channels * config.spec_bins // 4
327
+ self.embedding_hidden_size = config.embedding_hidden_size
328
+
329
+ self.spec_conv0 = CausalConv2d(
330
+ in_channels=1,
331
+ out_channels=config.conv_channels,
332
+ kernel_size=config.conv_kernel_size_input,
333
+ bias=False,
334
+ separable=True,
335
+ fstride=1,
336
+ lookahead=config.conv_lookahead,
337
+ )
338
+ self.spec_conv1 = CausalConv2d(
339
+ in_channels=config.conv_channels,
340
+ out_channels=config.conv_channels,
341
+ kernel_size=config.conv_kernel_size_inner,
342
+ bias=False,
343
+ separable=True,
344
+ fstride=2,
345
+ lookahead=config.conv_lookahead,
346
+ )
347
+ self.spec_conv2 = CausalConv2d(
348
+ in_channels=config.conv_channels,
349
+ out_channels=config.conv_channels,
350
+ kernel_size=config.conv_kernel_size_inner,
351
+ bias=False,
352
+ separable=True,
353
+ fstride=2,
354
+ lookahead=config.conv_lookahead,
355
+ )
356
+ self.spec_conv3 = CausalConv2d(
357
+ in_channels=config.conv_channels,
358
+ out_channels=config.conv_channels,
359
+ kernel_size=config.conv_kernel_size_inner,
360
+ bias=False,
361
+ separable=True,
362
+ fstride=1,
363
+ lookahead=config.conv_lookahead,
364
+ )
365
+
366
+ self.df_conv0 = CausalConv2d(
367
+ in_channels=2,
368
+ out_channels=config.conv_channels,
369
+ kernel_size=config.conv_kernel_size_input,
370
+ bias=False,
371
+ separable=True,
372
+ fstride=1,
373
+ )
374
+ self.df_conv1 = CausalConv2d(
375
+ in_channels=config.conv_channels,
376
+ out_channels=config.conv_channels,
377
+ kernel_size=config.conv_kernel_size_inner,
378
+ bias=False,
379
+ separable=True,
380
+ fstride=2,
381
+ )
382
+ self.df_fc_emb = nn.Sequential(
383
+ GroupedLinear(
384
+ config.conv_channels * config.df_bins // 2,
385
+ self.embedding_input_size,
386
+ groups=config.encoder_linear_groups
387
+ ),
388
+ nn.ReLU(inplace=True)
389
+ )
390
+
391
+ if config.encoder_combine_op == "concat":
392
+ self.embedding_input_size *= 2
393
+ self.combine = Concat()
394
+ else:
395
+ self.combine = Add()
396
+
397
+ # emb_gru
398
+ if config.spec_bins % 8 != 0:
399
+ raise AssertionError("spec_bins should be divisible by 8")
400
+
401
+ self.emb_gru = SqueezedGRU_S(
402
+ self.embedding_input_size,
403
+ self.embedding_hidden_size,
404
+ output_size=self.embedding_output_size,
405
+ num_layers=1,
406
+ batch_first=True,
407
+ skip_op=config.encoder_emb_skip_op,
408
+ linear_groups=config.encoder_emb_linear_groups,
409
+ activation_layer="relu",
410
+ )
411
+
412
+ # lsnr
413
+ self.lsnr_fc = nn.Sequential(
414
+ nn.Linear(self.embedding_output_size, 1),
415
+ nn.Sigmoid()
416
+ )
417
+ self.lsnr_scale = config.lsnr_max - config.lsnr_min
418
+ self.lsnr_offset = config.lsnr_min
419
+
420
+ def forward(self,
421
+ feat_power: torch.Tensor,
422
+ feat_spec: torch.Tensor,
423
+ hidden_state: torch.Tensor = None,
424
+ ):
425
+ # feat_power shape: (batch_size, 1, time_steps, spec_dim)
426
+ e0 = self.spec_conv0.forward(feat_power)
427
+ e1 = self.spec_conv1.forward(e0)
428
+ e2 = self.spec_conv2.forward(e1)
429
+ e3 = self.spec_conv3.forward(e2)
430
+ # e0 shape: [batch_size, channels, time_steps, spec_dim]
431
+ # e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
432
+ # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
433
+ # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
434
+
435
+ # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
436
+ c0 = self.df_conv0(feat_spec)
437
+ c1 = self.df_conv1(c0)
438
+ # c0 shape: [batch_size, channels, time_steps, df_bins]
439
+ # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
440
+
441
+ cemb = c1.permute(0, 2, 3, 1)
442
+ # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
443
+ cemb = cemb.flatten(2)
444
+ # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
445
+ cemb = self.df_fc_emb(cemb)
446
+ # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
447
+
448
+ # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
449
+ emb = e3.permute(0, 2, 3, 1)
450
+ # emb shape: [batch_size, time_steps, spec_dim // 4, channels]
451
+ emb = emb.flatten(2)
452
+ # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
453
+
454
+ emb = self.combine(emb, cemb)
455
+ # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
456
+ # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
457
+
458
+ emb, h = self.emb_gru.forward(emb, hidden_state)
459
+ # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
460
+ # h shape: [batch_size, 1, spec_dim]
461
+
462
+ lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
463
+ # lsnr shape: [batch_size, time_steps, 1]
464
+
465
+ return e0, e1, e2, e3, emb, c0, lsnr, h
466
+
467
+
468
+ class Decoder(nn.Module):
469
+ def __init__(self, config: DfNetConfig):
470
+ super(Decoder, self).__init__()
471
+
472
+ if config.spec_bins % 8 != 0:
473
+ raise AssertionError("spec_bins should be divisible by 8")
474
+
475
+ self.emb_in_dim = config.conv_channels * config.spec_bins // 4
476
+ self.emb_out_dim = config.conv_channels * config.spec_bins // 4
477
+ self.emb_hidden_dim = config.decoder_emb_hidden_size
478
+
479
+ self.emb_gru = SqueezedGRU_S(
480
+ self.emb_in_dim,
481
+ self.emb_hidden_dim,
482
+ output_size=self.emb_out_dim,
483
+ num_layers=config.decoder_emb_num_layers - 1,
484
+ batch_first=True,
485
+ skip_op=config.decoder_emb_skip_op,
486
+ linear_groups=config.decoder_emb_linear_groups,
487
+ activation_layer="relu",
488
+ )
489
+ self.conv3p = CausalConv2d(
490
+ in_channels=config.conv_channels,
491
+ out_channels=config.conv_channels,
492
+ kernel_size=1,
493
+ bias=False,
494
+ separable=True,
495
+ fstride=1,
496
+ lookahead=config.conv_lookahead,
497
+ )
498
+ self.convt3 = CausalConv2d(
499
+ in_channels=config.conv_channels,
500
+ out_channels=config.conv_channels,
501
+ kernel_size=config.conv_kernel_size_inner,
502
+ bias=False,
503
+ separable=True,
504
+ fstride=1,
505
+ lookahead=config.conv_lookahead,
506
+ )
507
+ self.conv2p = CausalConv2d(
508
+ in_channels=config.conv_channels,
509
+ out_channels=config.conv_channels,
510
+ kernel_size=1,
511
+ bias=False,
512
+ separable=True,
513
+ fstride=1,
514
+ lookahead=config.conv_lookahead,
515
+ )
516
+ self.convt2 = CausalConvTranspose2d(
517
+ in_channels=config.conv_channels,
518
+ out_channels=config.conv_channels,
519
+ kernel_size=config.convt_kernel_size_inner,
520
+ bias=False,
521
+ separable=True,
522
+ fstride=2,
523
+ lookahead=config.conv_lookahead,
524
+ )
525
+ self.conv1p = CausalConv2d(
526
+ in_channels=config.conv_channels,
527
+ out_channels=config.conv_channels,
528
+ kernel_size=1,
529
+ bias=False,
530
+ separable=True,
531
+ fstride=1,
532
+ lookahead=config.conv_lookahead,
533
+ )
534
+ self.convt1 = CausalConvTranspose2d(
535
+ in_channels=config.conv_channels,
536
+ out_channels=config.conv_channels,
537
+ kernel_size=config.convt_kernel_size_inner,
538
+ bias=False,
539
+ separable=True,
540
+ fstride=2,
541
+ lookahead=config.conv_lookahead,
542
+ )
543
+ self.conv0p = CausalConv2d(
544
+ in_channels=config.conv_channels,
545
+ out_channels=config.conv_channels,
546
+ kernel_size=1,
547
+ bias=False,
548
+ separable=True,
549
+ fstride=1,
550
+ lookahead=config.conv_lookahead,
551
+ )
552
+ self.conv0_out = CausalConv2d(
553
+ in_channels=config.conv_channels,
554
+ out_channels=1,
555
+ kernel_size=config.conv_kernel_size_inner,
556
+ activation_layer="sigmoid",
557
+ bias=False,
558
+ separable=True,
559
+ fstride=1,
560
+ lookahead=config.conv_lookahead,
561
+ )
562
+
563
+ def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
564
+ # Estimates erb mask
565
+ b, _, t, f8 = e3.shape
566
+
567
+ # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
568
+ emb, _ = self.emb_gru(emb)
569
+ # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
570
+ emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
571
+ e3 = self.convt3(self.conv3p(e3) + emb)
572
+ # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
573
+ e2 = self.convt2(self.conv2p(e2) + e3)
574
+ # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
575
+ e1 = self.convt1(self.conv1p(e1) + e2)
576
+ # e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
577
+ mask = self.conv0_out(self.conv0p(e0) + e1)
578
+ # mask shape: [batch_size, 1, time_steps, freq_dim]
579
+ return mask
580
+
581
+
582
+ class DfDecoder(nn.Module):
583
+ def __init__(self, config: DfNetConfig):
584
+ super(DfDecoder, self).__init__()
585
+
586
+ self.embedding_input_size = config.conv_channels * config.spec_bins // 4
587
+ self.df_decoder_hidden_size = config.df_decoder_hidden_size
588
+ self.df_num_layers = config.df_num_layers
589
+
590
+ self.df_order = config.df_order
591
+
592
+ self.df_bins = config.df_bins
593
+ self.df_out_ch = config.df_order * 2
594
+
595
+ self.df_convp = CausalConv2d(
596
+ config.conv_channels,
597
+ self.df_out_ch,
598
+ fstride=1,
599
+ kernel_size=(config.df_pathway_kernel_size_t, 1),
600
+ separable=True,
601
+ bias=False,
602
+ )
603
+ self.df_gru = SqueezedGRU_S(
604
+ self.embedding_input_size,
605
+ self.df_decoder_hidden_size,
606
+ num_layers=self.df_num_layers,
607
+ batch_first=True,
608
+ skip_op="none",
609
+ activation_layer="relu",
610
+ )
611
+
612
+ if config.df_gru_skip == "none":
613
+ self.df_skip = None
614
+ elif config.df_gru_skip == "identity":
615
+ if config.embedding_hidden_size != config.df_decoder_hidden_size:
616
+ raise AssertionError("Dimensions do not match")
617
+ self.df_skip = nn.Identity()
618
+ elif config.df_gru_skip == "grouped_linear":
619
+ self.df_skip = GroupedLinear(
620
+ self.embedding_input_size,
621
+ self.df_decoder_hidden_size,
622
+ groups=config.df_decoder_linear_groups
623
+ )
624
+ else:
625
+ raise NotImplementedError()
626
+
627
+ self.df_out: nn.Module
628
+ out_dim = self.df_bins * self.df_out_ch
629
+
630
+ self.df_out = nn.Sequential(
631
+ GroupedLinear(
632
+ input_size=self.df_decoder_hidden_size,
633
+ hidden_size=out_dim,
634
+ groups=config.df_decoder_linear_groups
635
+ ),
636
+ nn.Tanh()
637
+ )
638
+ self.df_fc_a = nn.Sequential(
639
+ nn.Linear(self.df_decoder_hidden_size, 1),
640
+ nn.Sigmoid()
641
+ )
642
+
643
+ def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
644
+ # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
645
+ b, t, _ = emb.shape
646
+ df_coefs, _ = self.df_gru(emb)
647
+ if self.df_skip is not None:
648
+ df_coefs = df_coefs + self.df_skip(emb)
649
+ # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
650
+
651
+ # c0 shape: [batch_size, channels, time_steps, df_bins]
652
+ c0 = self.df_convp(c0)
653
+ # c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
654
+ c0 = c0.permute(0, 2, 3, 1)
655
+ # c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
656
+
657
+ df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
658
+ # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
659
+ df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
660
+ # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
661
+ df_coefs = df_coefs + c0
662
+ # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
663
+ return df_coefs
664
+
665
+
666
+ class DfOutputReshapeMF(nn.Module):
667
+ """Coefficients output reshape for multiframe/MultiFrameModule
668
+
669
+ Requires input of shape B, C, T, F, 2.
670
+ """
671
+
672
+ def __init__(self, df_order: int, df_bins: int):
673
+ super().__init__()
674
+ self.df_order = df_order
675
+ self.df_bins = df_bins
676
+
677
+ def forward(self, coefs: torch.Tensor) -> torch.Tensor:
678
+ # [B, T, F, O*2] -> [B, O, T, F, 2]
679
+ new_shape = list(coefs.shape)
680
+ new_shape[-1] = -1
681
+ new_shape.append(2)
682
+ coefs = coefs.view(new_shape)
683
+ coefs = coefs.permute(0, 3, 1, 2, 4)
684
+ return coefs
685
+
686
+
687
+ class Mask(nn.Module):
688
+ def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
689
+ super().__init__()
690
+ self.use_post_filter = use_post_filter
691
+ self.eps = eps
692
+
693
+ def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
694
+ """
695
+ Post-Filter
696
+
697
+ A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
698
+ https://arxiv.org/abs/2008.04259
699
+
700
+ :param mask: Real valued mask, typically of shape [B, C, T, F].
701
+ :param beta: Global gain factor.
702
+ :return:
703
+ """
704
+ mask_sin = mask * torch.sin(np.pi * mask / 2)
705
+ mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
706
+ return mask_pf
707
+
708
+ def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
709
+ # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
710
+
711
+ if not self.training and self.use_post_filter:
712
+ mask = self.post_filter(mask)
713
+
714
+ # mask shape: [batch_size, 1, time_steps, spec_bins]
715
+ mask = mask.unsqueeze(4)
716
+ # mask shape: [batch_size, 1, time_steps, spec_bins, 1]
717
+ return spec * mask
718
+
719
+
720
+ class DeepFiltering(nn.Module):
721
+ def __init__(self,
722
+ df_bins: int,
723
+ df_order: int,
724
+ lookahead: int = 0,
725
+ ):
726
+ super(DeepFiltering, self).__init__()
727
+ self.df_bins = df_bins
728
+ self.df_order = df_order
729
+ self.need_unfold = df_order > 1
730
+ self.lookahead = lookahead
731
+
732
+ self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)
733
+
734
+ def spec_unfold(self, spec: torch.Tensor):
735
+ """
736
+ Pads and unfolds the spectrogram according to frame_size.
737
+ :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
738
+ :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
739
+ """
740
+ if self.need_unfold:
741
+ # spec shape: [batch_size, spec_bins, time_steps]
742
+ spec_pad = self.pad(spec)
743
+ # spec_pad shape: [batch_size, 1, time_steps_pad, spec_bins]
744
+ spec_unfold = spec_pad.unfold(2, self.df_order, 1)
745
+ # spec_unfold shape: [batch_size, 1, time_steps, spec_bins, df_order]
746
+ return spec_unfold
747
+ else:
748
+ return spec.unsqueeze(-1)
749
+
750
+ def forward(self,
751
+ spec: torch.Tensor,
752
+ coefs: torch.Tensor,
753
+ ):
754
+ # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
755
+ spec_u = self.spec_unfold(torch.view_as_complex(spec))
756
+ # spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order]
757
+
758
+ # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
759
+ coefs = torch.view_as_complex(coefs)
760
+ # coefs shape: [batch_size, df_order, time_steps, df_bins]
761
+ spec_f = spec_u.narrow(-2, 0, self.df_bins)
762
+ # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]
763
+
764
+ coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
765
+ # coefs shape: [batch_size, 1, df_order, time_steps, df_bins]
766
+
767
+ spec_f = self.df(spec_f, coefs)
768
+ # spec_f shape: [batch_size, 1, time_steps, df_bins]
769
+
770
+ if self.training:
771
+ spec = spec.clone()
772
+ spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
773
+ # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
774
+ return spec
775
+
776
+ @staticmethod
777
+ def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
778
+ """
779
+ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
780
+ :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
781
+ :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
782
+ :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
783
+ """
784
+ return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
785
+
786
+
787
+ class DfNet(nn.Module):
788
+ def __init__(self, config: DfNetConfig):
789
+ super(DfNet, self).__init__()
790
+ self.config = config
791
+
792
+ self.stft = ConvSTFT(
793
+ nfft=config.nfft,
794
+ win_size=config.win_size,
795
+ hop_size=config.hop_size,
796
+ win_type=config.win_type,
797
+ feature_type="complex",
798
+ requires_grad=False
799
+ )
800
+ self.istft = ConviSTFT(
801
+ nfft=config.nfft,
802
+ win_size=config.win_size,
803
+ hop_size=config.hop_size,
804
+ win_type=config.win_type,
805
+ feature_type="complex",
806
+ requires_grad=False
807
+ )
808
+
809
+ self.encoder = Encoder(config)
810
+ self.decoder = Decoder(config)
811
+
812
+ self.df_decoder = DfDecoder(config)
813
+ self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
814
+ self.df_op = DeepFiltering(
815
+ df_bins=config.df_bins,
816
+ df_order=config.df_order,
817
+ lookahead=config.df_lookahead,
818
+ )
819
+
820
+ self.mask = Mask(use_post_filter=config.use_post_filter)
821
+
822
+ def forward(self,
823
+ spec_complex: torch.Tensor,
824
+ ):
825
+ feat_power = torch.square(torch.abs(spec_complex))
826
+ feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
827
+ # feat_power shape: [batch_size, spec_bins, time_steps]
828
+ # feat_power shape: [batch_size, 1, spec_bins, time_steps]
829
+ # feat_power shape: [batch_size, 1, time_steps, spec_bins]
830
+ feat_power = feat_power.detach()
831
+
832
+ # spec shape: [batch_size, spec_bins, time_steps]
833
+ feat_spec = torch.view_as_real(spec_complex)
834
+ # spec shape: [batch_size, spec_bins, time_steps, 2]
835
+ feat_spec = feat_spec.permute(0, 3, 2, 1)
836
+ # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
837
+ feat_spec = feat_spec[..., :self.df_decoder.df_bins]
838
+ # feat_spec shape: [batch_size, 2, time_steps, df_bins]
839
+ feat_spec = feat_spec.detach()
840
+
841
+ # spec shape: [batch_size, spec_bins, time_steps]
842
+ spec = torch.unsqueeze(spec_complex, dim=1)
843
+ # spec shape: [batch_size, 1, spec_bins, time_steps]
844
+ spec = spec.permute(0, 1, 3, 2)
845
+ # spec shape: [batch_size, 1, time_steps, spec_bins]
846
+ spec = torch.view_as_real(spec)
847
+ # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
848
+ spec = spec.detach()
849
+
850
+ e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
851
+
852
+ mask = self.decoder.forward(emb, e3, e2, e1, e0)
853
+ # mask shape: [batch_size, 1, time_steps, spec_bins]
854
+ if torch.any(mask > 1) or torch.any(mask < 0):
855
+ raise AssertionError
856
+
857
+ spec_m = self.mask.forward(spec, mask)
858
+
859
+ # lsnr shape: [batch_size, time_steps, 1]
860
+ lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
861
+ # lsnr shape: [batch_size, 1, time_steps]
862
+
863
+ df_coefs = self.df_decoder.forward(emb, c0)
864
+ df_coefs = self.df_out_transform(df_coefs)
865
+ # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
866
+
867
+ spec_e = self.df_op.forward(spec.clone(), df_coefs)
868
+ # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
869
+
870
+ spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
871
+
872
+ spec_e = torch.squeeze(spec_e, dim=1)
873
+ spec_e = spec_e.permute(0, 2, 1, 3)
874
+ # spec_e shape: [batch_size, spec_bins, time_steps, 2]
875
+
876
+ mask = torch.squeeze(mask, dim=1)
877
+ mask = mask.permute(0, 2, 1)
878
+ # mask shape: [batch_size, spec_bins, time_steps]
879
+
880
+ return spec_e, mask, lsnr
881
+
882
+
883
+ class DfNetPretrainedModel(DfNet):
884
+ def __init__(self,
885
+ config: DfNetConfig,
886
+ ):
887
+ super(DfNetPretrainedModel, self).__init__(
888
+ config=config,
889
+ )
890
+
891
+ @classmethod
892
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
893
+ config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
894
+
895
+ model = cls(config)
896
+
897
+ if os.path.isdir(pretrained_model_name_or_path):
898
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
899
+ else:
900
+ ckpt_file = pretrained_model_name_or_path
901
+
902
+ with open(ckpt_file, "rb") as f:
903
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
904
+ model.load_state_dict(state_dict, strict=True)
905
+ return model
906
+
907
+ def save_pretrained(self,
908
+ save_directory: Union[str, os.PathLike],
909
+ state_dict: Optional[dict] = None,
910
+ ):
911
+
912
+ model = self
913
+
914
+ if state_dict is None:
915
+ state_dict = model.state_dict()
916
+
917
+ os.makedirs(save_directory, exist_ok=True)
918
+
919
+ # save state dict
920
+ model_file = os.path.join(save_directory, MODEL_FILE)
921
+ torch.save(state_dict, model_file)
922
+
923
+ # save config
924
+ config_file = os.path.join(save_directory, CONFIG_FILE)
925
+ self.config.to_yaml_file(config_file)
926
+ return save_directory
927
+
928
+
929
+ def main():
930
+
931
+ transformer = torchaudio.transforms.Spectrogram(
932
+ n_fft=512,
933
+ win_length=200,
934
+ hop_length=80,
935
+ window_fn=torch.hamming_window,
936
+ power=None,
937
+ )
938
+
939
+ config = DfNetConfig()
940
+ model = DfNetPretrainedModel(config=config)
941
+
942
+ inputs = torch.randn(size=(1, 16000), dtype=torch.float32)
943
+ spec_complex = transformer.forward(inputs)
944
+ spec_complex = spec_complex[:, :-1, :]
945
+
946
+ output = model.forward(spec_complex)
947
+ print(output[1].shape)
948
+ return
949
+
950
+
951
+ if __name__ == "__main__":
952
+ main()
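DfNetPretrainedModel only adds the save/load pair around DfNet. A round-trip sketch under the conventions above (model.pt plus a YAML config in one directory); the directory name is hypothetical:

    import torch

    from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
    from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNetPretrainedModel

    model = DfNetPretrainedModel(DfNetConfig())
    save_dir = "trained_models/dfnet-demo"      # hypothetical output directory
    model.save_pretrained(save_dir)             # writes model.pt and the config file

    reloaded = DfNetPretrainedModel.from_pretrained(save_dir)
    spec_complex = torch.randn(1, 256, 201, dtype=torch.complex64)  # [batch, spec_bins, time_steps]
    spec_e, mask, lsnr = reloaded.forward(spec_complex)
    print(spec_e.shape, mask.shape, lsnr.shape)  # [1, 256, 201, 2] [1, 256, 201] [1, 1, 201]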
toolbox/torchaudio/models/frcrn/modeling_frcrn.py CHANGED
@@ -57,6 +57,7 @@ class FRCRN(nn.Module):
57
  nfft=self.nfft,
58
  win_size=self.win_size,
59
  hop_size=self.hop_size,
 
60
  feature_type="complex",
61
  requires_grad=False
62
  )
@@ -194,7 +195,7 @@ class FRCRN(nn.Module):
194
  }]
195
  return params
196
 
197
- def mag_pha_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
198
  """
199
 
200
  :param est_mask: torch.Tensor, shape: [b, n+2, t]
@@ -230,10 +231,11 @@ class FRCRN(nn.Module):
230
  mask_re = est_mask[:, :self.freq_bins, :]
231
  mask_im = est_mask[:, self.freq_bins:, :]
232
 
233
- amp_loss = F.mse_loss(gth_mask_re, mask_re)
234
- phase_loss = F.mse_loss(gth_mask_im, mask_im)
235
 
236
- return amp_loss, phase_loss
 
237
 
238
 
239
  MODEL_FILE = "model.pt"
 
57
  nfft=self.nfft,
58
  win_size=self.win_size,
59
  hop_size=self.hop_size,
60
+ win_type=self.win_type,
61
  feature_type="complex",
62
  requires_grad=False
63
  )
 
195
  }]
196
  return params
197
 
198
+ def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
199
  """
200
 
201
  :param est_mask: torch.Tensor, shape: [b, n+2, t]
 
231
  mask_re = est_mask[:, :self.freq_bins, :]
232
  mask_im = est_mask[:, self.freq_bins:, :]
233
 
234
+ loss_re = F.mse_loss(gth_mask_re, mask_re)
235
+ loss_im = F.mse_loss(gth_mask_im, mask_im)
236
 
237
+ loss = loss_re + loss_im
238
+ return loss
239
 
240
 
241
  MODEL_FILE = "model.pt"
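The FRCRN change above collapses the old (amp_loss, phase_loss) pair into one scalar: MSE on the real half of the target complex mask plus MSE on the imaginary half, which the training script then adds to the SI-SNR term with unit weights. A stand-in sketch of just that combination; the tensors below are random placeholders, not the real mask targets computed inside mask_loss_fn:

    import torch
    import torch.nn.functional as F

    # Placeholders standing in for the [b, freq_bins, t] halves of the complex masks.
    gth_mask_re, mask_re = torch.rand(2, 321, 100), torch.rand(2, 321, 100)
    gth_mask_im, mask_im = torch.rand(2, 321, 100), torch.rand(2, 321, 100)

    mask_loss = F.mse_loss(gth_mask_re, mask_re) + F.mse_loss(gth_mask_im, mask_im)
    neg_si_snr_loss = torch.tensor(-15.0)             # placeholder SI-SNR term
    loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss    # the new combined objective
    print(loss.item())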
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -929,5 +929,5 @@ def main():
929
  return
930
 
931
 
932
- if __name__ == '__main__':
933
  main()
 
929
  return
930
 
931
 
932
+ if __name__ == "__main__":
933
  main()