HoneyTian committed · Commit 94ba8b5 · 1 Parent(s): 909a27e
examples/dfnet/run.sh CHANGED
@@ -2,6 +2,9 @@
 
 : <<'END'
 
+sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
+--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
+--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
 
 sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-dns3 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
examples/dfnet/step_1_prepare_data.py CHANGED
@@ -104,7 +104,7 @@ def main():
     dataset = list()
 
     count = 0
-    process_bar = tqdm(desc="build dataset excel")
+    process_bar = tqdm(desc="build dataset jsonl")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
             if count >= args.max_count > 0:
examples/dfnet/step_2_train_model.py CHANGED
@@ -25,6 +25,8 @@ from tqdm import tqdm
 from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
 from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
 from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+from toolbox.torchaudio.losses.irm import IRMLoss
+from toolbox.torchaudio.losses.snr import LocalSNRLoss
 from toolbox.torchaudio.metrics.pesq import run_pesq_score
 from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
 from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
@@ -79,22 +81,20 @@ class CollateFunction(object):
             # noise_wave: torch.Tensor = sample["noise_wave"]
             clean_audio: torch.Tensor = sample["speech_wave"]
             noisy_audio: torch.Tensor = sample["mix_wave"]
-            snr_db: float = sample["snr_db"]
+            # snr_db: float = sample["snr_db"]
 
             clean_audios.append(clean_audio)
             noisy_audios.append(noisy_audio)
-            snr_db_list.append(snr_db)
 
         clean_audios = torch.stack(clean_audios)
         noisy_audios = torch.stack(noisy_audios)
-        snr_db_list = torch.tensor(snr_db_list)
 
         # assert
         if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
             raise AssertionError("nan or inf in clean_audios")
         if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
             raise AssertionError("nan or inf in noisy_audios")
-        return clean_audios, noisy_audios, snr_db_list
+        return clean_audios, noisy_audios
 
 
 collate_fn = CollateFunction()
@@ -146,7 +146,7 @@ def main():
         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
         collate_fn=collate_fn,
         pin_memory=False,
-        prefetch_factor=2,
+        prefetch_factor=None if platform.system() == "Windows" else 2,
     )
     valid_data_loader = DataLoader(
         dataset=valid_dataset,
@@ -157,7 +157,7 @@ def main():
         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
         collate_fn=collate_fn,
         pin_memory=False,
-        prefetch_factor=2,
+        prefetch_factor=None if platform.system() == "Windows" else 2,
    )
 
     # models
@@ -222,7 +222,6 @@ def main():
         factor_mag=1.0,
         reduction="mean"
     ).to(device)
-    lsnr_loss_fn = nn.L1Loss(reduction="mean")
 
     # training loop
 
@@ -247,8 +246,10 @@ def main():
 
         total_pesq_score = 0.
         total_loss = 0.
+        total_mr_stft_loss = 0.
         total_neg_si_snr_loss = 0.
         total_mask_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.
 
         progress_bar_train = tqdm(
@@ -256,20 +257,18 @@ def main():
             desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
-            clean_audios, noisy_audios, snr_db_list = train_batch
+            clean_audios, noisy_audios = train_batch
             clean_audios: torch.Tensor = clean_audios.to(device)
             noisy_audios: torch.Tensor = noisy_audios.to(device)
-            snr_db_list: torch.Tensor = snr_db_list.to(device)
 
             est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
 
-            print(f"est_mask.shape: {est_mask.shape}, est_mask.dtype: {est_mask.dtype}")
+            mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
             neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
             mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
-            # mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
-            # neg_si_snr_loss = lsnr_loss_fn.forward(lsnr, snr_db_list)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 1.0 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss.")
                 continue
@@ -286,22 +285,28 @@ def main():
 
             total_pesq_score += pesq_score
             total_loss += loss.item()
+            total_mr_stft_loss += mr_stft_loss.item()
             total_neg_si_snr_loss += neg_si_snr_loss.item()
             total_mask_loss += mask_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1
 
             average_pesq_score = round(total_pesq_score / total_batches, 4)
             average_loss = round(total_loss / total_batches, 4)
+            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
             average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
             average_mask_loss = round(total_mask_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
             progress_bar_train.update(1)
             progress_bar_train.set_postfix({
                 "lr": lr_scheduler.get_last_lr()[0],
                 "pesq_score": average_pesq_score,
                 "loss": average_loss,
+                "mr_stft_loss": average_mr_stft_loss,
                 "neg_si_snr_loss": average_neg_si_snr_loss,
                 "mask_loss": average_mask_loss,
+                "lsnr_loss": average_lsnr_loss,
             })
 
             # evaluation
@@ -312,8 +317,10 @@ def main():
 
                 total_pesq_score = 0.
                 total_loss = 0.
+                total_mr_stft_loss = 0.
                 total_neg_si_snr_loss = 0.
                 total_mask_loss = 0.
+                total_lsnr_loss = 0.
                 total_batches = 0.
 
                 progress_bar_train.close()
@@ -321,17 +328,18 @@ def main():
                     desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                 )
                 for eval_batch in valid_data_loader:
-                    clean_audios, noisy_audios, snr_db_list = eval_batch
+                    clean_audios, noisy_audios = eval_batch
                     clean_audios: torch.Tensor = clean_audios.to(device)
                     noisy_audios: torch.Tensor = noisy_audios.to(device)
-                    snr_db_list: torch.Tensor = snr_db_list.to(device)
 
                     est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
 
+                    mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
                     neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
                     mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+                    lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-                    loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+                    loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 1.0 * lsnr_loss
                     if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                         logger.info(f"find nan or inf in loss.")
                         continue
@@ -342,28 +350,36 @@ def main():
 
                     total_pesq_score += pesq_score
                     total_loss += loss.item()
+                    total_mr_stft_loss += mr_stft_loss.item()
                     total_neg_si_snr_loss += neg_si_snr_loss.item()
                     total_mask_loss += mask_loss.item()
+                    total_lsnr_loss += lsnr_loss.item()
                     total_batches += 1
 
                     average_pesq_score = round(total_pesq_score / total_batches, 4)
                     average_loss = round(total_loss / total_batches, 4)
+                    average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                     average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                     average_mask_loss = round(total_mask_loss / total_batches, 4)
+                    average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
                     progress_bar_eval.update(1)
                     progress_bar_eval.set_postfix({
                         "lr": lr_scheduler.get_last_lr()[0],
                         "pesq_score": average_pesq_score,
                         "loss": average_loss,
+                        "mr_stft_loss": average_mr_stft_loss,
                         "neg_si_snr_loss": average_neg_si_snr_loss,
                         "mask_loss": average_mask_loss,
+                        "lsnr_loss": average_lsnr_loss,
                     })
 
                 total_pesq_score = 0.
                 total_loss = 0.
+                total_mr_stft_loss = 0.
                 total_neg_si_snr_loss = 0.
                 total_mask_loss = 0.
+                total_lsnr_loss = 0.
                 total_batches = 0.
 
                 progress_bar_eval.close()
@@ -393,7 +409,7 @@ def main():
                     best_epoch_idx = epoch_idx
                     best_step_idx = step_idx
                     best_metric = average_pesq_score
-                elif average_pesq_score > best_metric:
+                elif average_pesq_score >= best_metric:
                     # greater is better.
                     best_epoch_idx = epoch_idx
                     best_step_idx = step_idx
@@ -407,8 +423,10 @@ def main():
                     "best_step_idx": best_step_idx,
                     "pesq_score": average_pesq_score,
                     "loss": average_loss,
+                    "mr_stft_loss": average_mr_stft_loss,
                     "neg_si_snr_loss": average_neg_si_snr_loss,
                     "mask_loss": average_mask_loss,
+                    "lsnr_loss": average_lsnr_loss,
                 }
                 metrics_filename = save_dir / "metrics_epoch.json"
                 with open(metrics_filename, "w", encoding="utf-8") as f:
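
Note on the prefetch_factor change: torch.utils.data.DataLoader only accepts a prefetch_factor when num_workers > 0, so the Windows branch (which pins num_workers to 0) has to pass None. A minimal runnable sketch of the same gating, under the diff's Windows-means-single-process assumption:

import os
import platform

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 4))
num_workers = 0 if platform.system() == "Windows" else (os.cpu_count() or 2) // 2

loader = DataLoader(
    dataset=dataset,
    batch_size=2,
    num_workers=num_workers,
    # prefetch_factor must be None when num_workers == 0; 2 matches the diff otherwise.
    prefetch_factor=None if num_workers == 0 else 2,
)
for batch in loader:
    pass  # batches of shape [2, 4]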
examples/dfnet/yaml/config.yaml CHANGED
@@ -31,10 +31,6 @@ encoder_emb_hidden_size: 256
 
 encoder_linear_groups: 32
 
-lsnr_max: 30
-lsnr_min: -15
-norm_tau: 1.
-
 decoder_emb_num_layers: 3
 decoder_emb_skip_op: "none"
 decoder_emb_linear_groups: 16
@@ -49,8 +45,15 @@ df_decoder_linear_groups: 16
 df_pathway_kernel_size_t: 5
 df_lookahead: 2
 
-# runtime
-use_post_filter: true
+# lsnr
+n_frame: 3
+lsnr_max: 30
+lsnr_min: -15
+norm_tau: 1.
+
+# data
+min_snr_db: -10
+max_snr_db: 20
 
 # train
 lr: 0.001
@@ -63,9 +66,9 @@ max_epochs: 100
 clip_grad_norm: 10.0
 seed: 1234
 
-min_snr_db: -10
-max_snr_db: 20
-
 num_workers: 8
-batch_size: 32
+batch_size: 4
 eval_steps: 10000
+
+# runtime
+use_post_filter: true
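
The min_snr_db / max_snr_db pair moved under # data because it bounds the SNR at which noise is mixed into speech when the jsonl dataset is built. A hedged sketch of that mixing rule, assuming a uniform SNR draw (the actual sampler lives in the dataset code, which this commit does not show; mix_at_snr is an illustrative name):

import random

import torch

def mix_at_snr(speech: torch.Tensor, noise: torch.Tensor,
               min_snr_db: float = -10.0, max_snr_db: float = 20.0) -> torch.Tensor:
    # assumes speech and noise have the same length
    snr_db = random.uniform(min_snr_db, max_snr_db)
    speech_power = speech.pow(2).mean().clamp_min(1e-12)
    noise_power = noise.pow(2).mean().clamp_min(1e-12)
    # scale noise so 10 * log10(speech_power / scaled_noise_power) == snr_db
    scale = torch.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10.0)))
    return speech + scale * noise

mix = mix_at_snr(torch.randn(16000), torch.randn(16000))
print(mix.shape)  # torch.Size([16000])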
examples/frcrn/step_1_prepare_data.py CHANGED
@@ -104,7 +104,7 @@ def main():
     dataset = list()
 
     count = 0
-    process_bar = tqdm(desc="build dataset excel")
+    process_bar = tqdm(desc="build dataset jsonl")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
             flag = random.random()
toolbox/torchaudio/losses/irm.py CHANGED
@@ -93,6 +93,69 @@ class CIRMLoss(nn.Module):
         return loss
 
 
+class IRMLoss(nn.Module):
+    """
+    https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L25
+    """
+    def __init__(self,
+                 n_fft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+                 center: bool = True,
+                 eps: float = 1e-8,
+                 reduction: str = "mean",
+                 ):
+        super(IRMLoss, self).__init__()
+        self.n_fft = n_fft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.center = center
+        self.eps = eps
+        self.reduction = reduction
+
+        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
+
+        if reduction not in ("sum", "mean"):
+            raise AssertionError("param reduction must be sum or mean.")
+
+    def forward(self, mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        # stft_clean, stft_noise shape: [b, f, t]
+        stft_clean = torch.stft(
+            clean,
+            n_fft=self.n_fft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+        stft_noise = torch.stft(
+            noise,
+            n_fft=self.n_fft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+
+        mag_clean = torch.abs(stft_clean)
+        mag_noise = torch.abs(stft_noise)
+
+        gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)
+
+        loss = F.l1_loss(gth_irm_mask, mask, reduction=self.reduction)
+        return loss
+
+
 def main():
     batch_size = 2
     signal_length = 16000
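
The target that IRMLoss regresses the predicted mask against is the ideal ratio mask |S| / (|S| + |N|) per time-frequency bin, clamped to [0, 1]. A self-contained numeric sketch with plain torch.stft and the same 512/256 framing:

import torch

clean = torch.randn(2, 16000)
noise = torch.randn(2, 16000)

window = torch.hann_window(512)
def stft(x):
    return torch.stft(x, n_fft=512, hop_length=256, win_length=512,
                      window=window, return_complex=True)

mag_clean = torch.abs(stft(clean))
mag_noise = torch.abs(stft(noise))

# ideal ratio mask: close to 1 where speech dominates, close to 0 where noise does
irm = (mag_clean / (mag_clean + mag_noise + 1e-8)).clamp(0, 1)
print(irm.shape)  # torch.Size([2, 257, 63])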
toolbox/torchaudio/losses/snr.py CHANGED
@@ -5,6 +5,9 @@ https://zhuanlan.zhihu.com/p/627039860
 """
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
+
+from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
 class NegativeSNRLoss(nn.Module):
@@ -83,6 +86,92 @@ class NegativeSISNRLoss(nn.Module):
         return -loss
 
 
+class LocalSNRLoss(nn.Module):
+    """
+    https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
+    """
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 nfft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+                 n_frame: int = 3,
+                 min_local_snr: int = -15,
+                 max_local_snr: int = 30,
+                 db: bool = True,
+                 center: bool = True,
+                 factor: float = 1,
+                 reduction: str = "mean",
+                 eps: float = 1e-8,
+                 ):
+        super(LocalSNRLoss, self).__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.center = center
+
+        self.factor = factor
+        self.reduction = reduction
+        self.eps = eps
+
+        self.lsnr_fn = LocalSnrTarget(
+            sample_rate=sample_rate,
+            nfft=nfft,
+            win_size=win_size,
+            hop_size=hop_size,
+            n_frame=n_frame,
+            min_local_snr=min_local_snr,
+            max_local_snr=max_local_snr,
+            db=db,
+        )
+
+        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
+
+    def forward(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if clean.shape != noisy.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        stft_clean = torch.stft(
+            clean,
+            n_fft=self.nfft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+        stft_noise = torch.stft(
+            noise,
+            n_fft=self.nfft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+
+        # stft shape: [b, f, t]; LocalSnrTarget expects [b, c, t, f]
+        stft_clean = torch.unsqueeze(torch.transpose(stft_clean, dim0=1, dim1=2), dim=1)
+        stft_noise = torch.unsqueeze(torch.transpose(stft_noise, dim0=1, dim1=2), dim=1)
+        # shape: [b, 1, t, f]
+
+        # lsnr shape: [b, 1, t]
+        lsnr = lsnr.squeeze(1)
+        # lsnr shape: [b, t]
+
+        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
+        # lsnr_gth shape: [b, t]
+
+        loss = F.mse_loss(lsnr, lsnr_gth) * self.factor
+        return loss
+
+
 def main():
     batch_size = 2
     signal_length = 16000
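
The core of LocalSNRLoss is a per-frame energy ratio in dB, clamped to the [min_local_snr, max_local_snr] band, which the model's lsnr head learns to predict. The arithmetic in isolation, with toy tensors standing in for the windowed frame energies:

import torch

eps = 1e-12
# stand-ins for the windowed per-frame energies of speech and noise, shape [b, t]
energy_clean = torch.rand(2, 63) + eps
energy_noise = torch.rand(2, 63) + eps

snr = energy_clean / energy_noise.clamp_min(eps)
snr_db = snr.clamp_min(eps).log10().mul(10.0)
lsnr_target = snr_db.clamp(-15, 30)  # same bounds as min_local_snr / max_local_snr
print(lsnr_target.shape)  # torch.Size([2, 63])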
toolbox/torchaudio/models/dfnet/configuration_dfnet.py CHANGED
@@ -31,10 +31,6 @@ class DfNetConfig(PretrainedConfig):
 
                 encoder_linear_groups: int = 32,
 
-                 lsnr_max: int = 30,
-                 lsnr_min: int = -15,
-                 norm_tau: float = 1.,
-
                 decoder_emb_num_layers: int = 3,
                 decoder_emb_skip_op: str = "none",
                 decoder_emb_linear_groups: int = 16,
@@ -49,7 +45,13 @@ class DfNetConfig(PretrainedConfig):
                 df_pathway_kernel_size_t: int = 5,
                 df_lookahead: int = 2,
 
-                 use_post_filter: bool = False,
+                 n_frame: int = 3,
+                 max_local_snr: int = 30,
+                 min_local_snr: int = -15,
+                 norm_tau: float = 1.,
+
+                 min_snr_db: float = -10,
+                 max_snr_db: float = 20,
 
                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
@@ -59,13 +61,12 @@ class DfNetConfig(PretrainedConfig):
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,
 
-                 min_snr_db: float = -10,
-                 max_snr_db: float = 20,
-
                 num_workers: int = 4,
                 batch_size: int = 4,
                 eval_steps: int = 25000,
 
+                 use_post_filter: bool = False,
+
                 **kwargs
                 ):
        super(DfNetConfig, self).__init__(**kwargs)
@@ -97,10 +98,6 @@ class DfNetConfig(PretrainedConfig):
        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op
 
-        self.lsnr_max = lsnr_max
-        self.lsnr_min = lsnr_min
-        self.norm_tau = norm_tau
-
        # decoder
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
@@ -117,10 +114,17 @@ class DfNetConfig(PretrainedConfig):
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead
 
-        # runtime
-        self.use_post_filter = use_post_filter
+        # lsnr
+        self.n_frame = n_frame
+        self.max_local_snr = max_local_snr
+        self.min_local_snr = min_local_snr
+        self.norm_tau = norm_tau
+
+        # data snr
+        self.min_snr_db = min_snr_db
+        self.max_snr_db = max_snr_db
 
-        #
+        # train
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
@@ -129,13 +133,13 @@ class DfNetConfig(PretrainedConfig):
        self.clip_grad_norm = clip_grad_norm
        self.seed = seed
 
-        self.min_snr_db = min_snr_db
-        self.max_snr_db = max_snr_db
-
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.eval_steps = eval_steps
 
+        # runtime
+        self.use_post_filter = use_post_filter
+
 
 if __name__ == "__main__":
    pass
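
A minimal construction sketch using the regrouped parameters; every keyword below appears in this diff, and the fields not shown are assumed to keep their class defaults:

from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig

config = DfNetConfig(
    n_frame=3,
    max_local_snr=30,
    min_local_snr=-15,
    min_snr_db=-10,
    max_snr_db=20,
    batch_size=4,
    use_post_filter=True,
)
# Encoder derives lsnr_scale / lsnr_offset from these two fields (see modeling_dfnet.py)
print(config.max_local_snr - config.min_local_snr, config.min_local_snr)  # 45 -15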
toolbox/torchaudio/models/dfnet/conv_stft.py CHANGED
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from scipy.signal import get_window
+from sympy.physics.units import power
 
 
 def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
@@ -40,7 +41,7 @@ class ConvSTFT(nn.Module):
                 win_size: int,
                 hop_size: int,
                 win_type: str = "hamming",
-                 feature_type: str = "real",
+                 power: int = None,
                 requires_grad: bool = False):
        super(ConvSTFT, self).__init__()
 
@@ -57,23 +58,29 @@ class ConvSTFT(nn.Module):
 
        self.stride = hop_size
        self.dim = self.nfft
-        self.feature_type = feature_type
+        self.power = power
 
    def forward(self, inputs: torch.Tensor):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
 
-        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
+        matrix = F.conv1d(inputs, self.weight, stride=self.stride)
+        dim = self.dim // 2 + 1
+        real = matrix[:, :dim, :]
+        imag = matrix[:, dim:, :]
+        spec = torch.complex(real, imag)
 
-        if self.feature_type == "complex":
-            return outputs
-        else:
-            dim = self.dim // 2 + 1
-            real = outputs[:, :dim, :]
-            imag = outputs[:, dim:, :]
+        if self.power is None:
+            return spec
+        elif self.power == 1:
            mags = torch.sqrt(real**2 + imag**2)
-            phase = torch.atan2(imag, real)
-            return mags, phase
+            # phase = torch.atan2(imag, real)
+            return mags
+        elif self.power == 2:
+            power = real**2 + imag**2
+            return power
+        else:
+            raise AssertionError
 
 
 class ConviSTFT(nn.Module):
@@ -83,7 +90,6 @@ class ConviSTFT(nn.Module):
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
-                 feature_type: str = "real",
                 requires_grad: bool = False):
        super(ConviSTFT, self).__init__()
        if nfft is None:
@@ -100,45 +106,41 @@ class ConviSTFT(nn.Module):
 
        self.stride = hop_size
        self.dim = self.nfft
-        self.feature_type = feature_type
 
        self.register_buffer("window", window)
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
 
    def forward(self,
-                inputs: torch.Tensor,
-                phase: torch.Tensor = None):
+                inputs: torch.Tensor):
        """
-        :param inputs: torch.Tensor, shape: [b, n+2, t] (complex spec) or [b, n//2+1, t] (mags)
-        :param phase: torch.Tensor, shape: [b, n//2+1, t]
+        :param inputs: torch.Tensor, shape: [b, f, t]
        :return:
        """
-        if phase is not None:
-            real = inputs * torch.cos(phase)
-            imag = inputs * torch.sin(phase)
-            inputs = torch.cat([real, imag], 1)
-        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
+        inputs = torch.view_as_real(inputs)
+        matrix = torch.concat(tensors=[inputs[..., 0], inputs[..., 1]], dim=1)
+
+        waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
 
        # this is from torch-stft: https://github.com/pseeth/torch-stft
-        t = self.window.repeat(1, 1, inputs.size(-1))**2
+        t = self.window.repeat(1, 1, matrix.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
-        outputs = outputs / (coff + 1e-8)
-        return outputs
+        waveform = waveform / (coff + 1e-8)
+        return waveform
 
 
 def main():
-    stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
-    istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")
+    stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
+    istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)
 
    mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
 
    spec = stft.forward(mixture)
    # shape: [batch_size, freq_bins, time_steps]
-    print(spec.shape)
+    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
 
    waveform = istft.forward(spec)
    # shape: [batch_size, channels, num_samples]
-    print(waveform.shape)
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
 
    return
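The new power argument follows the torchaudio.transforms.Spectrogram convention: None returns the complex spectrum, 1 the magnitude, 2 the power spectrum. That convention can be checked directly against torchaudio (512/256 settings assumed here):

import torch
import torchaudio

x = torch.randn(1, 16000)

for power in (None, 1.0, 2.0):
    transform = torchaudio.transforms.Spectrogram(
        n_fft=512, win_length=512, hop_length=256, power=power,
    )
    spec = transform(x)
    # power=None -> torch.complex64 spectrum, 1.0 -> magnitude, 2.0 -> power
    print(power, spec.shape, spec.dtype)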
toolbox/torchaudio/models/dfnet/modeling_dfnet.py CHANGED
@@ -13,6 +13,7 @@ import torchaudio
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
 from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT
+from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
 MODEL_FILE = "model.pt"
@@ -415,8 +416,8 @@ class Encoder(nn.Module):
             nn.Linear(self.embedding_output_size, 1),
             nn.Sigmoid()
         )
-        self.lsnr_scale = config.lsnr_max - config.lsnr_min
-        self.lsnr_offset = config.lsnr_min
+        self.lsnr_scale = config.max_local_snr - config.min_local_snr
+        self.lsnr_offset = config.min_local_snr
 
     def forward(self,
                 feat_power: torch.Tensor,
@@ -789,7 +790,7 @@ class DfNet(nn.Module):
     def __init__(self, config: DfNetConfig):
         super(DfNet, self).__init__()
         self.config = config
-        self.eps = 1e-8
+        self.eps = 1e-12
 
         self.freq_bins = self.config.nfft // 2 + 1
 
@@ -803,7 +804,7 @@ class DfNet(nn.Module):
             win_size=config.win_size,
             hop_size=config.hop_size,
             win_type=config.win_type,
-            feature_type="complex",
+            power=None,
             requires_grad=False
         )
         self.istft = ConviSTFT(
@@ -811,7 +812,6 @@ class DfNet(nn.Module):
             win_size=config.win_size,
             hop_size=config.hop_size,
             win_type=config.win_type,
-            feature_type="complex",
             requires_grad=False
         )
 
@@ -828,98 +828,121 @@ class DfNet(nn.Module):
 
         self.mask = Mask(use_post_filter=config.use_post_filter)
 
-    def forward(self,
-                noisy: torch.Tensor,
-                ):
-        if noisy.dim() == 2:
-            noisy = torch.unsqueeze(noisy, dim=1)
-        _, _, n_samples = noisy.shape
+        self.lsnr_fn = LocalSnrTarget(
+            sample_rate=config.sample_rate,
+            nfft=config.nfft,
+            win_size=config.win_size,
+            hop_size=config.hop_size,
+            n_frame=config.n_frame,
+            min_local_snr=config.min_local_snr,
+            max_local_snr=config.max_local_snr,
+            db=True,
+        )
+
+    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, n_samples = signal.shape
         remainder = (n_samples - self.win_size) % self.hop_size
         if remainder > 0:
             n_samples_pad = self.hop_size - remainder
-            noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
+            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
+        return signal, n_samples
 
-        # [batch_size, freq_bins * 2, time_steps]
-        cmp_spec = self.stft.forward(noisy)
-        # [batch_size, 1, freq_bins * 2, time_steps]
-        cmp_spec = torch.unsqueeze(cmp_spec, 1)
+    def forward(self,
+                noisy: torch.Tensor,
+                ):
+        """
+        :param noisy:
+        :return:
+        est_spec: shape: [b, f, t], torch.complex64
+        est_wav: shape: [b, num_samples]
+        est_mask: shape: [b, 257, t]
+        lsnr: shape: [b, 1, t]
+        """
+        noisy, n_samples = self.signal_prepare(noisy)
 
-        # [batch_size, 2, freq_bins, time_steps]
-        cmp_spec = torch.cat([
-            cmp_spec[:, :, :self.freq_bins, :],
-            cmp_spec[:, :, self.freq_bins:, :],
-        ], dim=1)
-        # n//2+1 -> n//2; 257 -> 256
+        # noisy shape: [b, num_samples_pad]
+        cmp_spec = self.stft.forward(noisy)
+        # cmp_spec shape: [b, f, t], torch.complex64
+        cmp_spec = torch.view_as_real(cmp_spec)
+        # cmp_spec shape: [b, f, t, 2]
+        cmp_spec = cmp_spec.permute(0, 3, 1, 2)
+        # cmp_spec shape: [b, 2, f, t]
         cmp_spec = cmp_spec[:, :, :-1, :]
+        # cmp_spec shape: [b, 2, spec_bins, t]
+        # n//2+1 -> n//2; 257 -> 256
 
         spec = torch.unsqueeze(cmp_spec, dim=4)
-        # [batch_size, 2, freq_bins, time_steps, 1]
+        # spec shape: [b, 2, spec_bins, t, 1]
         spec = spec.permute(0, 4, 3, 2, 1)
-        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
+        # spec shape: [b, 1, t, spec_bins, 2]
 
         feat_power = torch.sum(torch.square(spec), dim=-1)
-        # feat_power shape: [batch_size, 1, time_steps, spec_bins]
+        # feat_power shape: [b, 1, t, spec_bins]
 
         feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
-        # feat_spec shape: [batch_size, 2, time_steps, freq_bins]
+        # feat_spec shape: [b, 2, t, spec_bins]
         feat_spec = feat_spec[..., :self.df_decoder.df_bins]
-        # feat_spec shape: [batch_size, 2, time_steps, df_bins]
+        # feat_spec shape: [b, 2, t, df_bins]
 
         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
 
         mask = self.decoder.forward(emb, e3, e2, e1, e0)
-        # mask shape: [batch_size, 1, time_steps, spec_bins]
+        # mask shape: [b, 1, t, spec_bins]
         if torch.any(mask > 1) or torch.any(mask < 0):
             raise AssertionError
 
         spec_m = self.mask.forward(spec, mask)
 
-        # lsnr shape: [batch_size, time_steps, 1]
+        # lsnr shape: [b, t, 1]
         lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
-        # lsnr shape: [batch_size, 1, time_steps]
+        # lsnr shape: [b, 1, t]
 
         df_coefs = self.df_decoder.forward(emb, c0)
         df_coefs = self.df_out_transform(df_coefs)
-        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
+        # df_coefs shape: [b, df_order, t, df_bins, 2]
 
         spec_e = self.df_op.forward(spec.clone(), df_coefs)
-        # est_spec shape: [batch_size, 1, time_steps, spec_bins, 2]
+        # est_spec shape: [b, 1, t, spec_bins, 2]
 
         spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
 
         spec_e = torch.squeeze(spec_e, dim=1)
         spec_e = spec_e.permute(0, 2, 1, 3)
-        # spec_e shape: [batch_size, spec_bins, time_steps, 2]
+        # spec_e shape: [b, spec_bins, t, 2]
 
         mask = torch.squeeze(mask, dim=1)
         mask = mask.permute(0, 2, 1)
-        # mask shape: [b, 256, t]
+        # mask shape: [b, spec_bins, t]
         est_mask = self.mask_transfer(mask)
-        # est_mask shape: [b, 257, t]
+        # est_mask shape: [b, f, t]
 
-        # spec_e shape: [b, 256, t, 2]
+        # spec_e shape: [b, spec_bins, t, 2]
         est_spec = self.spec_transfer(spec_e)
-        # est_spec shape: [b, 257*2, t]
+        # est_spec shape: [b, f, t], torch.complex64
+
         est_wav = self.istft.forward(est_spec)
         est_wav = torch.squeeze(est_wav, dim=1)
         est_wav = est_wav[:, :n_samples]
         # est_wav shape: [b, n_samples]
+
         return est_spec, est_wav, est_mask, lsnr
 
     def spec_transfer(self, spec_e: torch.Tensor) -> torch.Tensor:
-        # spec_e shape: [b, 256, t, 2]
+        # spec_e shape: [b, spec_bins, t, 2]
         b, _, t, _ = spec_e.shape
-        est_spec = torch.cat(tensors=[
-            torch.concat(tensors=[
+        est_spec = torch.complex(
+            real=torch.concat(tensors=[
                 spec_e[..., 0],
                 torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
             ], dim=1),
-            torch.concat(tensors=[
+            imag=torch.concat(tensors=[
                 spec_e[..., 1],
                 torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
             ], dim=1),
-        ], dim=1)
-        # est_spec shape: [b, 257*2, t]
+        )
+        # est_spec shape: [b, f, t]
         return est_spec
 
     def mask_transfer(self, mask: torch.Tensor) -> torch.Tensor:
@@ -934,29 +957,58 @@ class DfNet(nn.Module):
 
     def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
         """
-
-        :param est_mask: torch.Tensor, shape: [b, n+2, t]
+        :param est_mask: torch.Tensor, shape: [b, 257, t]
         :param clean:
         :param noisy:
         :return:
         """
-        clean_stft = self.stft(clean)
-        clean_re = clean_stft[:, :self.freq_bins, :]
-        clean_im = clean_stft[:, self.freq_bins:, :]
-        clean_power = clean_re ** 2 + clean_im ** 2
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        clean, _ = self.signal_prepare(clean)
+        noise, _ = self.signal_prepare(noise)
 
-        noisy_stft = self.stft(noisy)
-        noisy_re = noisy_stft[:, :self.freq_bins, :]
-        noisy_im = noisy_stft[:, self.freq_bins:, :]
-        noisy_power = noisy_re ** 2 + noisy_im ** 2
+        stft_clean = self.stft.forward(clean)
+        mag_clean = torch.abs(stft_clean)
 
-        speech_irm = clean_power / (noisy_power + self.eps)
-        # speech_irm = torch.pow(speech_irm, self.irm_beta)
+        stft_noise = self.stft.forward(noise)
+        mag_noise = torch.abs(stft_noise)
 
-        loss = F.mse_loss(est_mask, speech_irm)
+        gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)
+
+        loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean")
 
         return loss
 
+    def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        clean, _ = self.signal_prepare(clean)
+        noise, _ = self.signal_prepare(noise)
+
+        stft_clean = self.stft.forward(clean)
+        stft_noise = self.stft.forward(noise)
+        # shape: [b, f, t]
+        stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
+        stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
+        # shape: [b, t, f]
+        stft_clean = torch.unsqueeze(stft_clean, dim=1)
+        stft_noise = torch.unsqueeze(stft_noise, dim=1)
+        # shape: [b, 1, t, f]
+
+        # lsnr shape: [b, 1, t]
+        lsnr = lsnr.squeeze(1)
+        # lsnr shape: [b, t]
+
+        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
+        # lsnr_gth shape: [b, t]
+
+        loss = F.mse_loss(lsnr, lsnr_gth)
+        return loss
+
 
 class DfNetPretrainedModel(DfNet):
     def __init__(self,
@@ -1011,8 +1063,12 @@ def main():
 
     noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
 
-    output = model.forward(noisy)
-    print(output[1].shape)
+    est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
+    print(f"est_spec.shape: {est_spec.shape}")
+    print(f"est_wav.shape: {est_wav.shape}")
+    print(f"est_mask.shape: {est_mask.shape}")
+    print(f"lsnr.shape: {lsnr.shape}")
+
     return
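signal_prepare pads the waveform so that (n_samples - win_size) is a multiple of hop_size, i.e. the last analysis window ends exactly at the padded signal's end. The padding arithmetic on its own, with the 512/256 framing used in the configs:

import torch
import torch.nn.functional as F

win_size, hop_size = 512, 256
signal = torch.randn(1, 1, 16000)

n_samples = signal.shape[-1]
remainder = (n_samples - win_size) % hop_size   # (16000 - 512) % 256 == 128
if remainder > 0:
    # append hop_size - remainder zeros: 256 - 128 == 128 -> 16128 samples
    signal = F.pad(signal, pad=(0, hop_size - remainder), mode="constant", value=0)
print(signal.shape)  # torch.Size([1, 1, 16128])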
toolbox/torchaudio/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/modules/local_snr_target.py ADDED
@@ -0,0 +1,149 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
+"""
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torchaudio
+
+
+def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torch.Tensor:
+    if n_frame % 2 == 0:
+        n_frame += 1
+    n_frame_half = n_frame // 2
+
+    # spec shape: [b, c, t, f, 2]
+    spec = F.pad(spec.pow(2).sum(-1).sum(-1), (n_frame_half, n_frame_half, 0, 0))
+    # spec shape: [b, c, t + 2 * n_frame_half]
+
+    weight = torch.hann_window(n_frame, device=device, dtype=spec.dtype)
+    # weight shape: [n_frame]
+
+    spec = spec.unfold(-1, size=n_frame, step=1) * weight
+    # spec shape: [b, c, t, n_frame]
+
+    result = torch.sum(spec, dim=-1).div(n_frame)
+    # result shape: [b, c, t]
+    return result
+
+
+def local_snr(spec_clean: torch.Tensor,
+              spec_noise: torch.Tensor,
+              n_frame: int = 5,
+              db: bool = False,
+              eps: float = 1e-12,
+              ):
+    # [b, c, t, f]
+    spec_clean = torch.view_as_real(spec_clean)
+    spec_noise = torch.view_as_real(spec_noise)
+    # [b, c, t, f, 2]
+
+    energy_clean = local_energy(spec_clean, n_frame=n_frame, device=spec_clean.device)
+    energy_noise = local_energy(spec_noise, n_frame=n_frame, device=spec_noise.device)
+    # [b, c, t]
+
+    snr = energy_clean / energy_noise.clamp_min(eps)
+    # snr shape: [b, c, t]
+
+    if db:
+        snr = snr.clamp_min(eps).log10().mul(10)
+    return snr, energy_clean, energy_noise
+
+
+class LocalSnrTarget(nn.Module):
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 nfft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+
+                 n_frame: int = 3,
+
+                 min_local_snr: int = -15,
+                 max_local_snr: int = 30,
+
+                 db: bool = True,
+                 ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+
+        self.n_frame = n_frame
+
+        self.min_local_snr = min_local_snr
+        self.max_local_snr = max_local_snr
+
+        self.db = db
+
+    def forward(self,
+                spec_clean: torch.Tensor,
+                spec_noise: torch.Tensor,
+                ) -> torch.Tensor:
+        """
+        :param spec_clean: torch.complex, shape: [b, c, t, f]
+        :param spec_noise: torch.complex, shape: [b, c, t, f]
+        :return: lsnr, shape: [b, t]
+        """
+        lsnr, _, _ = local_snr(
+            spec_clean=spec_clean,
+            spec_noise=spec_noise,
+            n_frame=self.n_frame,
+            db=self.db,
+        )
+        # lsnr shape: [b, c, t]
+        lsnr = lsnr.clamp(self.min_local_snr, self.max_local_snr).squeeze(1)
+        # lsnr shape: [b, t]
+        return lsnr
+
+
+def main():
+    sample_rate = 8000
+    nfft = 512
+    win_size = 512
+    hop_size = 256
+    window_fn = "hamming"
+
+    transform = torchaudio.transforms.Spectrogram(
+        n_fft=nfft,
+        win_length=win_size,
+        hop_length=hop_size,
+        power=None,
+        window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
+    )
+
+    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
+
+    spec = transform.forward(noisy)
+    spec = spec.permute(0, 2, 1)
+    spec = torch.unsqueeze(spec, dim=1)
+    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
+
+    # [b, c, t, f]
+    # spec = torch.view_as_real(spec)
+    # [b, c, t, f, 2]
+
+    local = LocalSnrTarget(
+        sample_rate=sample_rate,
+        nfft=nfft,
+        win_size=win_size,
+        hop_size=hop_size,
+        n_frame=5,
+        min_local_snr=-15,
+        max_local_snr=30,
+        db=True,
+    )
+    lsnr_target = local.forward(spec, spec)
+    print(f"lsnr_target.shape: {lsnr_target.shape}")
+    return
+
+
+if __name__ == "__main__":
+    main()
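
A quick sanity check on LocalSnrTarget: feeding the same spectrum as clean and noise makes the windowed energy ratio 1 everywhere, so the dB target is 0, well inside the [-15, 30] clamp (shapes follow the forward docstring):

import torch

from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget

spec = torch.randn(1, 1, 63, 257, dtype=torch.complex64)  # [b, c, t, f]
target = LocalSnrTarget(n_frame=5, min_local_snr=-15, max_local_snr=30, db=True)

lsnr = target.forward(spec, spec)   # clean == noise -> ratio 1 -> 0 dB
print(lsnr.shape)                   # torch.Size([1, 63])
print(torch.allclose(lsnr, torch.zeros_like(lsnr), atol=1e-4))  # True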