HoneyTian commited on
Commit
9a0003a
·
1 Parent(s): 74d0273
examples/dfnet/run.sh CHANGED
@@ -3,7 +3,7 @@
3
  : <<'END'
4
 
5
 
6
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
9
 
 
3
  : <<'END'
4
 
5
 
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-dns3 \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
9
 
examples/dfnet/yaml/config.yaml CHANGED
@@ -51,3 +51,21 @@ df_lookahead: 2
51
 
52
  # runtime
53
  use_post_filter: true
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # runtime
53
  use_post_filter: true
54
+
55
+ # train
56
+ lr: 0.001
57
+ lr_scheduler: "CosineAnnealingLR"
58
+ lr_scheduler_kwargs:
59
+ T_max: 250000
60
+ eta_min: 0.0001
61
+
62
+ max_epochs: 100
63
+ clip_grad_norm: 10.0
64
+ seed: 1234
65
+
66
+ min_snr_db: -10
67
+ max_snr_db: 20
68
+
69
+ num_workers: 8
70
+ batch_size: 32
71
+ eval_steps: 10000
examples/frcrn/step_2_train_model.py CHANGED
@@ -1,6 +1,8 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  """
 
 
4
  FRCRN 论文中:
5
  在 WSJ0 数据集上训练了 120 个 epoch 得到 pesq 3.62, stoi 98.24, si-snr 21.33
6
 
@@ -188,17 +190,17 @@ def main():
188
  if last_step_idx != -1:
189
  logger.info(f"resume from steps-{last_step_idx}.")
190
  model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
191
- optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
192
 
193
  logger.info(f"load state dict for model.")
194
  with open(model_pt.as_posix(), "rb") as f:
195
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
196
  model.load_state_dict(state_dict, strict=True)
197
 
198
- logger.info(f"load state dict for optimizer.")
199
- with open(optimizer_pth.as_posix(), "rb") as f:
200
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
201
- optimizer.load_state_dict(state_dict)
202
 
203
  if config.lr_scheduler == "CosineAnnealingLR":
204
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
@@ -377,15 +379,12 @@ def main():
377
  model_to_delete: Path = model_list.pop(0)
378
  shutil.rmtree(model_to_delete.as_posix())
379
 
380
- # save optim
381
- torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
382
-
383
  # save metric
384
  if best_metric is None:
385
  best_epoch_idx = epoch_idx
386
  best_step_idx = step_idx
387
  best_metric = average_pesq_score
388
- elif average_pesq_score > best_metric:
389
  # great is better.
390
  best_epoch_idx = epoch_idx
391
  best_step_idx = step_idx
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  """
4
+ https://arxiv.org/abs/2206.07293
5
+
6
  FRCRN 论文中:
7
  在 WSJ0 数据集上训练了 120 个 epoch 得到 pesq 3.62, stoi 98.24, si-snr 21.33
8
 
 
190
  if last_step_idx != -1:
191
  logger.info(f"resume from steps-{last_step_idx}.")
192
  model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
193
+ # optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
194
 
195
  logger.info(f"load state dict for model.")
196
  with open(model_pt.as_posix(), "rb") as f:
197
  state_dict = torch.load(f, map_location="cpu", weights_only=True)
198
  model.load_state_dict(state_dict, strict=True)
199
 
200
+ # logger.info(f"load state dict for optimizer.")
201
+ # with open(optimizer_pth.as_posix(), "rb") as f:
202
+ # state_dict = torch.load(f, map_location="cpu", weights_only=True)
203
+ # optimizer.load_state_dict(state_dict)
204
 
205
  if config.lr_scheduler == "CosineAnnealingLR":
206
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
 
379
  model_to_delete: Path = model_list.pop(0)
380
  shutil.rmtree(model_to_delete.as_posix())
381
 
 
 
 
382
  # save metric
383
  if best_metric is None:
384
  best_epoch_idx = epoch_idx
385
  best_step_idx = step_idx
386
  best_metric = average_pesq_score
387
+ elif average_pesq_score >= best_metric:
388
  # great is better.
389
  best_epoch_idx = epoch_idx
390
  best_step_idx = step_idx
examples/mpnet/yaml/config.yaml CHANGED
@@ -25,3 +25,6 @@ dist_config:
25
  dist_backend: nccl
26
  dist_url: tcp://localhost:54321
27
  world_size: 1
 
 
 
 
25
  dist_backend: nccl
26
  dist_url: tcp://localhost:54321
27
  world_size: 1
28
+
29
+ discriminator_dim: 32
30
+ discriminator_in_channel: 2
main.py CHANGED
@@ -16,6 +16,7 @@ import log
16
  from project_settings import environment, project_path, log_directory
17
  from toolbox.os.command import Command
18
  from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
 
19
 
20
  log.setup_size_rotating(log_directory=log_directory)
21
 
@@ -93,6 +94,12 @@ denoise_engines = {
93
  "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
94
  }
95
  },
 
 
 
 
 
 
96
  }
97
 
98
 
 
16
  from project_settings import environment, project_path, log_directory
17
  from toolbox.os.command import Command
18
  from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
19
+ from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
20
 
21
  log.setup_size_rotating(log_directory=log_directory)
22
 
 
94
  "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
95
  }
96
  },
97
+ "frcrn-dns3": {
98
+ "infer_cls": InferenceFRCRN,
99
+ "kwargs": {
100
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/frcrn-dns3-220k-steps.zip").as_posix()
101
+ }
102
+ },
103
  }
104
 
105
 
toolbox/torchaudio/models/dfnet/configuration_dfnet.py CHANGED
@@ -50,6 +50,22 @@ class DfNetConfig(PretrainedConfig):
50
  df_lookahead: int = 2,
51
 
52
  use_post_filter: bool = False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  **kwargs
54
  ):
55
  super(DfNetConfig, self).__init__(**kwargs)
@@ -104,6 +120,22 @@ class DfNetConfig(PretrainedConfig):
104
  # runtime
105
  self.use_post_filter = use_post_filter
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  if __name__ == "__main__":
109
  pass
 
50
  df_lookahead: int = 2,
51
 
52
  use_post_filter: bool = False,
53
+
54
+ lr: float = 0.001,
55
+ lr_scheduler: str = "CosineAnnealingLR",
56
+ lr_scheduler_kwargs: dict = None,
57
+
58
+ max_epochs: int = 100,
59
+ clip_grad_norm: float = 10.,
60
+ seed: int = 1234,
61
+
62
+ min_snr_db: float = -10,
63
+ max_snr_db: float = 20,
64
+
65
+ num_workers: int = 4,
66
+ batch_size: int = 4,
67
+ eval_steps: int = 25000,
68
+
69
  **kwargs
70
  ):
71
  super(DfNetConfig, self).__init__(**kwargs)
 
120
  # runtime
121
  self.use_post_filter = use_post_filter
122
 
123
+ #
124
+ self.lr = lr
125
+ self.lr_scheduler = lr_scheduler
126
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
127
+
128
+ self.max_epochs = max_epochs
129
+ self.clip_grad_norm = clip_grad_norm
130
+ self.seed = seed
131
+
132
+ self.min_snr_db = min_snr_db
133
+ self.max_snr_db = max_snr_db
134
+
135
+ self.num_workers = num_workers
136
+ self.batch_size = batch_size
137
+ self.eval_steps = eval_steps
138
+
139
 
140
  if __name__ == "__main__":
141
  pass
toolbox/torchaudio/models/frcrn/inference_frcrn.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from pathlib import Path
5
+ import shutil
6
+ import tempfile, time
7
+ import zipfile
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+
14
+ torch.set_num_threads(1)
15
+
16
+ from project_settings import project_path
17
+ from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
18
+ from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRNPretrainedModel, MODEL_FILE
19
+
20
+ logger = logging.getLogger("toolbox")
21
+
22
+
23
+ class InferenceFRCRN(object):
24
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
25
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
26
+ self.device = torch.device(device)
27
+
28
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
29
+ config, model = self.load_models(self.pretrained_model_path_or_zip_file)
30
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
31
+
32
+ self.config = config
33
+ self.model = model
34
+ self.model.to(device)
35
+ self.model.eval()
36
+
37
+ def load_models(self, model_path: str):
38
+ model_path = Path(model_path)
39
+ if model_path.name.endswith(".zip"):
40
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
41
+ out_root = Path(tempfile.gettempdir()) / "nx_denoise"
42
+ out_root.mkdir(parents=True, exist_ok=True)
43
+ f_zip.extractall(path=out_root)
44
+ model_path = out_root / model_path.stem
45
+
46
+ config = FRCRNConfig.from_pretrained(
47
+ pretrained_model_name_or_path=model_path.as_posix(),
48
+ )
49
+ model = FRCRNPretrainedModel.from_pretrained(
50
+ pretrained_model_name_or_path=model_path.as_posix(),
51
+ )
52
+ model.to(self.device)
53
+ model.eval()
54
+
55
+ shutil.rmtree(model_path)
56
+ return config, model
57
+
58
+ def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
59
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
60
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
61
+
62
+ # noisy_audio shape: [batch_size, n_samples]
63
+ enhanced_audio = self.enhancement_by_tensor(noisy_audio)
64
+ # noisy_audio shape: [n_samples,]
65
+ return enhanced_audio.cpu().numpy()
66
+
67
+ def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
68
+ if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
69
+ raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
70
+
71
+ # noisy_audio shape: [batch_size, num_samples]
72
+ noisy_audios = noisy_audio.to(self.device)
73
+
74
+ with torch.no_grad():
75
+ est_spec, est_wav, est_mask = self.model.forward(noisy_audios)
76
+
77
+ # shape: [batch_size, num_samples]
78
+ enhanced_audio = torch.unsqueeze(est_wav, dim=1)
79
+ # shape: [batch_size, 1, num_samples]
80
+
81
+ enhanced_audio = enhanced_audio[0]
82
+
83
+ # enhanced_audio shape: [channels, num_samples]
84
+ return enhanced_audio
85
+
86
+
87
+ def main():
88
+ model_zip_file = project_path / "trained_models/frcrn-dns3.zip"
89
+ infer_model = InferenceFRCRN(model_zip_file)
90
+
91
+ sample_rate = 8000
92
+ noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_5.wav"
93
+ noisy_audio, sample_rate = librosa.load(
94
+ noisy_audio_file.as_posix(),
95
+ sr=sample_rate,
96
+ )
97
+ duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
98
+ # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
99
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
100
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
101
+
102
+ begin = time.time()
103
+ enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio)
104
+ time_cost = time.time() - begin
105
+ print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
106
+
107
+ filename = "enhanced_audio.wav"
108
+ torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
109
+
110
+ return
111
+
112
+
113
+ if __name__ == "__main__":
114
+ main()
toolbox/torchaudio/models/mpnet/yaml/config.yaml CHANGED
@@ -25,3 +25,6 @@ dist_config:
25
  dist_backend: nccl
26
  dist_url: tcp://localhost:54321
27
  world_size: 1
 
 
 
 
25
  dist_backend: nccl
26
  dist_url: tcp://localhost:54321
27
  world_size: 1
28
+
29
+ discriminator_dim: 32
30
+ discriminator_in_channel: 2