HoneyTian committed
Commit 9192cea · 1 Parent(s): f239cae
examples/dfnet/step_2_train_model.py CHANGED
@@ -187,18 +187,12 @@ def main():
     if last_step_idx != -1:
         logger.info(f"resume from steps-{last_step_idx}.")
         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
 
         logger.info(f"load state dict for model.")
         with open(model_pt.as_posix(), "rb") as f:
             state_dict = torch.load(f, map_location="cpu", weights_only=True)
         model.load_state_dict(state_dict, strict=True)
 
-        logger.info(f"load state dict for optimizer.")
-        with open(optimizer_pth.as_posix(), "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        optimizer.load_state_dict(state_dict)
-
     if config.lr_scheduler == "CosineAnnealingLR":
         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer,
@@ -270,14 +264,14 @@ def main():
             clean_audios: torch.Tensor = clean_audios.to(device)
             noisy_audios: torch.Tensor = noisy_audios.to(device)
 
-            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+            est_spec, est_wav, est_mask, lsnr, erb_encoder_h = model.forward(noisy_audios)
 
             mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
             neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
             mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 1.0 * lsnr_loss
+            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss.")
                 continue
@@ -341,14 +335,14 @@ def main():
             clean_audios: torch.Tensor = clean_audios.to(device)
             noisy_audios: torch.Tensor = noisy_audios.to(device)
 
-            est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+            est_spec, est_wav, est_mask, lsnr, erb_encoder_h = model.forward(noisy_audios)
 
             mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
             neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
             mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 1.0 * lsnr_loss
+            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss.")
                 continue
@@ -410,9 +404,6 @@ def main():
                 model_to_delete: Path = model_list.pop(0)
                 shutil.rmtree(model_to_delete.as_posix())
 
-            # save optim
-            torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
-
             # save metric
             if best_metric is None:
                 best_epoch_idx = epoch_idx
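With this change the run no longer writes or restores optimizer.pth: a resume reloads only the model weights, so the optimizer moments and the CosineAnnealingLR schedule restart from step 0. The lsnr loss weight also drops from 1.0 to 0.3, and model.forward now returns a fifth output, erb_encoder_h, which the training loop discards. A minimal sketch of one way to realign the cosine schedule with the resumed global step; the stand-in model, optimizer, and step count are hypothetical, and the replay loop is not part of this commit:

    import torch

    model = torch.nn.Linear(8, 8)    # stand-in for the DfNet model
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    last_step_idx = 120000           # stand-in for the step parsed from the checkpoint name

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=250000,                # T_max from the dfnet config added below
        eta_min=0.0001,
    )
    if last_step_idx != -1:
        for _ in range(last_step_idx):
            lr_scheduler.step()      # replay past steps so the LR matches the global step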
examples/dfnet/yaml/config.yaml CHANGED
@@ -68,7 +68,7 @@ seed: 1234
 
 num_workers: 8
 batch_size: 64
-eval_steps: 20000
+eval_steps: 10000
 
 # runtime
 use_post_filter: true
examples/dtln/yaml/config.yaml CHANGED
@@ -24,6 +24,6 @@ max_epochs: 100
 clip_grad_norm: 10.0
 seed: 1234
 
-batch_size: 128
+batch_size: 64
 num_workers: 4
-eval_steps: 25000
+eval_steps: 15000
main.py CHANGED
@@ -62,10 +62,10 @@ def shell(cmd: str):
 
 
 denoise_engines = {
-    "mpnet-nx-speech": {
-        "infer_cls": InferenceMPNet,
+    "dfnet-nx-dns3": {
+        "infer_cls": InferenceFRCRN,
         "kwargs": {
-            "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-nx-speech.zip").as_posix()
+            "pretrained_model_path_or_zip_file": (project_path / "trained_models/dfnet-nx-dns3.zip").as_posix()
         }
     },
     "frcrn-dns3": {
@@ -74,6 +74,12 @@ denoise_engines = {
             "pretrained_model_path_or_zip_file": (project_path / "trained_models/frcrn-dns3.zip").as_posix()
         }
     },
+    "mpnet-nx-speech": {
+        "infer_cls": InferenceMPNet,
+        "kwargs": {
+            "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-nx-speech.zip").as_posix()
+        }
+    },
 }
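Besides reordering the registry, the commit wires the new "dfnet-nx-dns3" engine to InferenceFRCRN rather than the InferenceDfNet class added below, which looks like a copy-paste slip worth double-checking. A hypothetical usage sketch of the registry as declared; the lookup pattern is inferred from the dict shape, not from code shown in this diff:

    engine = "dfnet-nx-dns3"
    entry = denoise_engines[engine]

    # Instantiate the engine class with its pretrained zip path.
    infer_engine = entry["infer_cls"](**entry["kwargs"])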
 
toolbox/torchaudio/models/dfnet/inference_dfnet.py ADDED
@@ -0,0 +1,115 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import logging
+from pathlib import Path
+import shutil
+import tempfile, time
+import zipfile
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+torch.set_num_threads(1)
+
+from project_settings import project_path
+from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
+from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNetPretrainedModel, MODEL_FILE
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceDfNet(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.model = model
+        self.model.to(device)
+        self.model.eval()
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = DfNetConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model = DfNetPretrainedModel.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model.to(self.device)
+        model.eval()
+
+        shutil.rmtree(model_path)
+        return config, model
+
+    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
+        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+        noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+        # noisy_audio shape: [batch_size, n_samples]
+        enhanced_audio = self.enhancement_by_tensor(noisy_audio)
+        # enhanced_audio shape: [channels, num_samples]
+        enhanced_audio = enhanced_audio[0]
+        # enhanced_audio shape: [num_samples]
+        return enhanced_audio.cpu().numpy()
+
+    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
+
+        # noisy_audio shape: [batch_size, num_samples]
+        noisy_audios = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios)
+
+        # shape: [batch_size, num_samples]
+        enhanced_audio = torch.unsqueeze(est_wav, dim=1)
+        # shape: [batch_size, 1, num_samples]
+
+        enhanced_audio = enhanced_audio[0]
+        # shape: [channels, num_samples]
+        return enhanced_audio
+
+
+def main():
+    model_zip_file = project_path / "trained_models/dfnet-nx-dns3.zip"
+    infer_model = InferenceDfNet(model_zip_file)
+
+    sample_rate = 8000
+    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
+    noisy_audio, sample_rate = librosa.load(
+        noisy_audio_file.as_posix(),
+        sr=sample_rate,
+    )
+    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
+    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
+    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+    noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+    begin = time.time()
+    enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio)
+    time_cost = time.time() - begin
+    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
+    filename = "enhanced_audio.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
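A minimal usage sketch for the new InferenceDfNet class via its ndarray entry point; the random input is a placeholder for real 8 kHz audio in [-1, 1]. Note that enhancement_by_tensor unpacks four outputs from model.forward while the updated training script unpacks five (including erb_encoder_h), so one of the two presumably tracks an older model signature:

    import numpy as np

    from project_settings import project_path
    from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet

    infer = InferenceDfNet((project_path / "trained_models/dfnet-nx-dns3.zip").as_posix())

    noisy = np.random.uniform(-0.5, 0.5, size=(8000,)).astype(np.float32)  # 1 s at 8 kHz
    enhanced = infer.enhancement_by_ndarray(noisy)  # returns a [num_samples] ndarray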
toolbox/torchaudio/models/dfnet/yaml/config.yaml ADDED
@@ -0,0 +1,74 @@
+model_name: "dfnet"
+
+# spec
+sample_rate: 8000
+nfft: 512
+win_size: 200
+hop_size: 80
+
+spec_bins: 256
+
+# model
+conv_channels: 64
+conv_kernel_size_input:
+- 3
+- 3
+conv_kernel_size_inner:
+- 1
+- 3
+conv_lookahead: 0
+
+convt_kernel_size_inner:
+- 1
+- 3
+
+embedding_hidden_size: 256
+encoder_combine_op: "concat"
+
+encoder_emb_skip_op: "none"
+encoder_emb_linear_groups: 16
+encoder_emb_hidden_size: 256
+
+encoder_linear_groups: 32
+
+decoder_emb_num_layers: 3
+decoder_emb_skip_op: "none"
+decoder_emb_linear_groups: 16
+decoder_emb_hidden_size: 256
+
+df_decoder_hidden_size: 256
+df_num_layers: 2
+df_order: 5
+df_bins: 96
+df_gru_skip: "grouped_linear"
+df_decoder_linear_groups: 16
+df_pathway_kernel_size_t: 5
+df_lookahead: 2
+
+# lsnr
+n_frame: 3
+lsnr_max: 30
+lsnr_min: -15
+norm_tau: 1.
+
+# data
+min_snr_db: -10
+max_snr_db: 20
+
+# train
+lr: 0.001
+lr_scheduler: "CosineAnnealingLR"
+lr_scheduler_kwargs:
+  T_max: 250000
+  eta_min: 0.0001
+
+max_epochs: 100
+clip_grad_norm: 10.0
+seed: 1234
+
+num_workers: 8
+batch_size: 64
+eval_steps: 10000
+
+# runtime
+use_post_filter: true
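A worked check of the STFT timing this config implies; the inputs are copied from the YAML and the derived numbers are plain arithmetic, not additional settings:

    sample_rate = 8000
    win_size = 200
    hop_size = 80

    print(win_size / sample_rate * 1000)  # 25.0 ms analysis window
    print(hop_size / sample_rate * 1000)  # 10.0 ms hop, i.e. 100 frames per second
    print(win_size - hop_size)            # 120-sample overlap carried between streaming chunks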
toolbox/torchaudio/modules/conv_stft.py CHANGED
@@ -141,6 +141,7 @@ class ConviSTFT(nn.Module):
         # waveform = waveform / coff
         return waveform
 
+    @torch.no_grad()
     def forward_chunk(self,
                       spec: torch.Tensor,
                       waveform_cache: torch.Tensor = None,
@@ -163,22 +164,14 @@
         overlap_size = self.win_size - self.hop_size
 
         if waveform_cache is not None:
-            waveform_overlap = waveform_current[:, :, :overlap_size] + waveform_cache
-            waveform_non_overlap = waveform_current[:, :, overlap_size:-self.hop_size]
-            waveform_output = torch.cat(tensors=[waveform_overlap, waveform_non_overlap], dim=-1)
-            new_waveform_cache = waveform_current[:, :, -self.hop_size:]
-        else:
-            waveform_output = waveform_current[:, :, :-self.hop_size]
-            new_waveform_cache = waveform_current[:, :, -self.hop_size:]
+            waveform_current[:, :, :overlap_size] += waveform_cache
+        waveform_output = waveform_current[:, :, :self.hop_size]
+        new_waveform_cache = waveform_current[:, :, self.hop_size:]
 
         if coff_cache is not None:
-            coff_overlap = coff_current[:, :, :overlap_size] + coff_cache
-            coff_non_overlap = coff_current[:, :, overlap_size:-self.hop_size]
-            coff_output = torch.cat(tensors=[coff_overlap, coff_non_overlap], dim=-1)
-            new_coff_cache = coff_current[:, :, -self.hop_size:]
-        else:
-            coff_output = coff_current[:, :, :-self.hop_size]
-            new_coff_cache = coff_current[:, :, -self.hop_size:]
+            coff_current[:, :, :overlap_size] += coff_cache
+        coff_output = coff_current[:, :, :self.hop_size]
+        new_coff_cache = coff_current[:, :, self.hop_size:]
 
         waveform_output = waveform_output / (coff_output + 1e-8)
         return waveform_output, new_waveform_cache, new_coff_cache
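The rewritten forward_chunk emits exactly hop_size samples per chunk and carries the remaining win_size - hop_size samples forward as cache, replacing the earlier variable-length output. A standalone numpy sketch of that cache discipline (one frame per chunk, synthesis window of ones, the coff normalization omitted), verifying that the streamed output matches one-shot overlap-add:

    import numpy as np

    win_size, hop_size = 200, 80
    overlap_size = win_size - hop_size  # 120 samples carried across chunks

    frames = [np.random.randn(win_size) for _ in range(5)]

    # Reference: plain overlap-add over all frames at once.
    full = np.zeros(hop_size * len(frames) + overlap_size)
    for i, frame in enumerate(frames):
        full[i * hop_size: i * hop_size + win_size] += frame

    # Streaming: add the cache onto the overlap region, emit hop_size samples,
    # keep the tail as the new cache -- the same shape as the revised forward_chunk.
    cache = None
    streamed = []
    for frame in frames:
        current = frame.copy()
        if cache is not None:
            current[:overlap_size] += cache
        streamed.append(current[:hop_size])
        cache = current[hop_size:]
    streamed = np.concatenate(streamed)

    assert np.allclose(streamed, full[:len(streamed)])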
toolbox/torchaudio/modules/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/modules/utils/ema.py ADDED
@@ -0,0 +1,12 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import torch.nn as nn
+
+
+class ExponentialMovingAverage(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+
+if __name__ == "__main__":
+    pass
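The added ExponentialMovingAverage is still an empty stub. For reference, a generic shadow-weight EMA of the kind such a module usually implements; this sketches the common pattern only, not this repo's eventual implementation:

    import torch
    import torch.nn as nn

    class ShadowEMA:
        def __init__(self, model: nn.Module, decay: float = 0.999):
            self.decay = decay
            self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

        @torch.no_grad()
        def update(self, model: nn.Module):
            # shadow <- decay * shadow + (1 - decay) * current weights
            for k, v in model.state_dict().items():
                if v.dtype.is_floating_point:
                    self.shadow[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
                else:
                    self.shadow[k].copy_(v)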