Spaces:
Running
Running
update
Browse files- examples/dfnet/run.sh +1 -1
- examples/dfnet/yaml/config.yaml +18 -0
- examples/frcrn/step_2_train_model.py +8 -9
- examples/mpnet/yaml/config.yaml +3 -0
- main.py +7 -0
- toolbox/torchaudio/models/dfnet/configuration_dfnet.py +32 -0
- toolbox/torchaudio/models/frcrn/inference_frcrn.py +114 -0
- toolbox/torchaudio/models/mpnet/yaml/config.yaml +3 -0
examples/dfnet/run.sh
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
-
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
9 |
|
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-dns3 \
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
9 |
|
examples/dfnet/yaml/config.yaml
CHANGED
@@ -51,3 +51,21 @@ df_lookahead: 2
|
|
51 |
|
52 |
# runtime
|
53 |
use_post_filter: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
# runtime
|
53 |
use_post_filter: true
|
54 |
+
|
55 |
+
# train
|
56 |
+
lr: 0.001
|
57 |
+
lr_scheduler: "CosineAnnealingLR"
|
58 |
+
lr_scheduler_kwargs:
|
59 |
+
T_max: 250000
|
60 |
+
eta_min: 0.0001
|
61 |
+
|
62 |
+
max_epochs: 100
|
63 |
+
clip_grad_norm: 10.0
|
64 |
+
seed: 1234
|
65 |
+
|
66 |
+
min_snr_db: -10
|
67 |
+
max_snr_db: 20
|
68 |
+
|
69 |
+
num_workers: 8
|
70 |
+
batch_size: 32
|
71 |
+
eval_steps: 10000
|
examples/frcrn/step_2_train_model.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
|
|
|
|
4 |
FRCRN 论文中:
|
5 |
在 WSJ0 数据集上训练了 120 个 epoch 得到 pesq 3.62, stoi 98.24, si-snr 21.33
|
6 |
|
@@ -188,17 +190,17 @@ def main():
|
|
188 |
if last_step_idx != -1:
|
189 |
logger.info(f"resume from steps-{last_step_idx}.")
|
190 |
model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
|
191 |
-
optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
|
192 |
|
193 |
logger.info(f"load state dict for model.")
|
194 |
with open(model_pt.as_posix(), "rb") as f:
|
195 |
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
196 |
model.load_state_dict(state_dict, strict=True)
|
197 |
|
198 |
-
logger.info(f"load state dict for optimizer.")
|
199 |
-
with open(optimizer_pth.as_posix(), "rb") as f:
|
200 |
-
|
201 |
-
optimizer.load_state_dict(state_dict)
|
202 |
|
203 |
if config.lr_scheduler == "CosineAnnealingLR":
|
204 |
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
@@ -377,15 +379,12 @@ def main():
|
|
377 |
model_to_delete: Path = model_list.pop(0)
|
378 |
shutil.rmtree(model_to_delete.as_posix())
|
379 |
|
380 |
-
# save optim
|
381 |
-
torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
|
382 |
-
|
383 |
# save metric
|
384 |
if best_metric is None:
|
385 |
best_epoch_idx = epoch_idx
|
386 |
best_step_idx = step_idx
|
387 |
best_metric = average_pesq_score
|
388 |
-
elif average_pesq_score
|
389 |
# great is better.
|
390 |
best_epoch_idx = epoch_idx
|
391 |
best_step_idx = step_idx
|
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
"""
|
4 |
+
https://arxiv.org/abs/2206.07293
|
5 |
+
|
6 |
FRCRN 论文中:
|
7 |
在 WSJ0 数据集上训练了 120 个 epoch 得到 pesq 3.62, stoi 98.24, si-snr 21.33
|
8 |
|
|
|
190 |
if last_step_idx != -1:
|
191 |
logger.info(f"resume from steps-{last_step_idx}.")
|
192 |
model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
|
193 |
+
# optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
|
194 |
|
195 |
logger.info(f"load state dict for model.")
|
196 |
with open(model_pt.as_posix(), "rb") as f:
|
197 |
state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
198 |
model.load_state_dict(state_dict, strict=True)
|
199 |
|
200 |
+
# logger.info(f"load state dict for optimizer.")
|
201 |
+
# with open(optimizer_pth.as_posix(), "rb") as f:
|
202 |
+
# state_dict = torch.load(f, map_location="cpu", weights_only=True)
|
203 |
+
# optimizer.load_state_dict(state_dict)
|
204 |
|
205 |
if config.lr_scheduler == "CosineAnnealingLR":
|
206 |
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
|
|
379 |
model_to_delete: Path = model_list.pop(0)
|
380 |
shutil.rmtree(model_to_delete.as_posix())
|
381 |
|
|
|
|
|
|
|
382 |
# save metric
|
383 |
if best_metric is None:
|
384 |
best_epoch_idx = epoch_idx
|
385 |
best_step_idx = step_idx
|
386 |
best_metric = average_pesq_score
|
387 |
+
elif average_pesq_score >= best_metric:
|
388 |
# great is better.
|
389 |
best_epoch_idx = epoch_idx
|
390 |
best_step_idx = step_idx
|
examples/mpnet/yaml/config.yaml
CHANGED
@@ -25,3 +25,6 @@ dist_config:
|
|
25 |
dist_backend: nccl
|
26 |
dist_url: tcp://localhost:54321
|
27 |
world_size: 1
|
|
|
|
|
|
|
|
25 |
dist_backend: nccl
|
26 |
dist_url: tcp://localhost:54321
|
27 |
world_size: 1
|
28 |
+
|
29 |
+
discriminator_dim: 32
|
30 |
+
discriminator_in_channel: 2
|
main.py
CHANGED
@@ -16,6 +16,7 @@ import log
|
|
16 |
from project_settings import environment, project_path, log_directory
|
17 |
from toolbox.os.command import Command
|
18 |
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
|
|
|
19 |
|
20 |
log.setup_size_rotating(log_directory=log_directory)
|
21 |
|
@@ -93,6 +94,12 @@ denoise_engines = {
|
|
93 |
"pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
|
94 |
}
|
95 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
}
|
97 |
|
98 |
|
|
|
16 |
from project_settings import environment, project_path, log_directory
|
17 |
from toolbox.os.command import Command
|
18 |
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
|
19 |
+
from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
|
20 |
|
21 |
log.setup_size_rotating(log_directory=log_directory)
|
22 |
|
|
|
94 |
"pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
|
95 |
}
|
96 |
},
|
97 |
+
"frcrn-dns3": {
|
98 |
+
"infer_cls": InferenceFRCRN,
|
99 |
+
"kwargs": {
|
100 |
+
"pretrained_model_path_or_zip_file": (project_path / "trained_models/frcrn-dns3-220k-steps.zip").as_posix()
|
101 |
+
}
|
102 |
+
},
|
103 |
}
|
104 |
|
105 |
|
toolbox/torchaudio/models/dfnet/configuration_dfnet.py
CHANGED
@@ -50,6 +50,22 @@ class DfNetConfig(PretrainedConfig):
|
|
50 |
df_lookahead: int = 2,
|
51 |
|
52 |
use_post_filter: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
**kwargs
|
54 |
):
|
55 |
super(DfNetConfig, self).__init__(**kwargs)
|
@@ -104,6 +120,22 @@ class DfNetConfig(PretrainedConfig):
|
|
104 |
# runtime
|
105 |
self.use_post_filter = use_post_filter
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
if __name__ == "__main__":
|
109 |
pass
|
|
|
50 |
df_lookahead: int = 2,
|
51 |
|
52 |
use_post_filter: bool = False,
|
53 |
+
|
54 |
+
lr: float = 0.001,
|
55 |
+
lr_scheduler: str = "CosineAnnealingLR",
|
56 |
+
lr_scheduler_kwargs: dict = None,
|
57 |
+
|
58 |
+
max_epochs: int = 100,
|
59 |
+
clip_grad_norm: float = 10.,
|
60 |
+
seed: int = 1234,
|
61 |
+
|
62 |
+
min_snr_db: float = -10,
|
63 |
+
max_snr_db: float = 20,
|
64 |
+
|
65 |
+
num_workers: int = 4,
|
66 |
+
batch_size: int = 4,
|
67 |
+
eval_steps: int = 25000,
|
68 |
+
|
69 |
**kwargs
|
70 |
):
|
71 |
super(DfNetConfig, self).__init__(**kwargs)
|
|
|
120 |
# runtime
|
121 |
self.use_post_filter = use_post_filter
|
122 |
|
123 |
+
#
|
124 |
+
self.lr = lr
|
125 |
+
self.lr_scheduler = lr_scheduler
|
126 |
+
self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
|
127 |
+
|
128 |
+
self.max_epochs = max_epochs
|
129 |
+
self.clip_grad_norm = clip_grad_norm
|
130 |
+
self.seed = seed
|
131 |
+
|
132 |
+
self.min_snr_db = min_snr_db
|
133 |
+
self.max_snr_db = max_snr_db
|
134 |
+
|
135 |
+
self.num_workers = num_workers
|
136 |
+
self.batch_size = batch_size
|
137 |
+
self.eval_steps = eval_steps
|
138 |
+
|
139 |
|
140 |
if __name__ == "__main__":
|
141 |
pass
|
toolbox/torchaudio/models/frcrn/inference_frcrn.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import logging
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
import tempfile, time
|
7 |
+
import zipfile
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
torch.set_num_threads(1)
|
15 |
+
|
16 |
+
from project_settings import project_path
|
17 |
+
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
|
18 |
+
from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRNPretrainedModel, MODEL_FILE
|
19 |
+
|
20 |
+
logger = logging.getLogger("toolbox")
|
21 |
+
|
22 |
+
|
23 |
+
class InferenceFRCRN(object):
|
24 |
+
def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
|
25 |
+
self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
|
26 |
+
self.device = torch.device(device)
|
27 |
+
|
28 |
+
logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
|
29 |
+
config, model = self.load_models(self.pretrained_model_path_or_zip_file)
|
30 |
+
logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
|
31 |
+
|
32 |
+
self.config = config
|
33 |
+
self.model = model
|
34 |
+
self.model.to(device)
|
35 |
+
self.model.eval()
|
36 |
+
|
37 |
+
def load_models(self, model_path: str):
|
38 |
+
model_path = Path(model_path)
|
39 |
+
if model_path.name.endswith(".zip"):
|
40 |
+
with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
|
41 |
+
out_root = Path(tempfile.gettempdir()) / "nx_denoise"
|
42 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
43 |
+
f_zip.extractall(path=out_root)
|
44 |
+
model_path = out_root / model_path.stem
|
45 |
+
|
46 |
+
config = FRCRNConfig.from_pretrained(
|
47 |
+
pretrained_model_name_or_path=model_path.as_posix(),
|
48 |
+
)
|
49 |
+
model = FRCRNPretrainedModel.from_pretrained(
|
50 |
+
pretrained_model_name_or_path=model_path.as_posix(),
|
51 |
+
)
|
52 |
+
model.to(self.device)
|
53 |
+
model.eval()
|
54 |
+
|
55 |
+
shutil.rmtree(model_path)
|
56 |
+
return config, model
|
57 |
+
|
58 |
+
def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
|
59 |
+
noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
|
60 |
+
noisy_audio = noisy_audio.unsqueeze(dim=0)
|
61 |
+
|
62 |
+
# noisy_audio shape: [batch_size, n_samples]
|
63 |
+
enhanced_audio = self.enhancement_by_tensor(noisy_audio)
|
64 |
+
# noisy_audio shape: [n_samples,]
|
65 |
+
return enhanced_audio.cpu().numpy()
|
66 |
+
|
67 |
+
def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
|
68 |
+
if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
|
69 |
+
raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
|
70 |
+
|
71 |
+
# noisy_audio shape: [batch_size, num_samples]
|
72 |
+
noisy_audios = noisy_audio.to(self.device)
|
73 |
+
|
74 |
+
with torch.no_grad():
|
75 |
+
est_spec, est_wav, est_mask = self.model.forward(noisy_audios)
|
76 |
+
|
77 |
+
# shape: [batch_size, num_samples]
|
78 |
+
enhanced_audio = torch.unsqueeze(est_wav, dim=1)
|
79 |
+
# shape: [batch_size, 1, num_samples]
|
80 |
+
|
81 |
+
enhanced_audio = enhanced_audio[0]
|
82 |
+
|
83 |
+
# enhanced_audio shape: [channels, num_samples]
|
84 |
+
return enhanced_audio
|
85 |
+
|
86 |
+
|
87 |
+
def main():
|
88 |
+
model_zip_file = project_path / "trained_models/frcrn-dns3.zip"
|
89 |
+
infer_model = InferenceFRCRN(model_zip_file)
|
90 |
+
|
91 |
+
sample_rate = 8000
|
92 |
+
noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_5.wav"
|
93 |
+
noisy_audio, sample_rate = librosa.load(
|
94 |
+
noisy_audio_file.as_posix(),
|
95 |
+
sr=sample_rate,
|
96 |
+
)
|
97 |
+
duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
|
98 |
+
# noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
|
99 |
+
noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
|
100 |
+
noisy_audio = noisy_audio.unsqueeze(dim=0)
|
101 |
+
|
102 |
+
begin = time.time()
|
103 |
+
enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio)
|
104 |
+
time_cost = time.time() - begin
|
105 |
+
print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
|
106 |
+
|
107 |
+
filename = "enhanced_audio.wav"
|
108 |
+
torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
|
109 |
+
|
110 |
+
return
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
main()
|
toolbox/torchaudio/models/mpnet/yaml/config.yaml
CHANGED
@@ -25,3 +25,6 @@ dist_config:
|
|
25 |
dist_backend: nccl
|
26 |
dist_url: tcp://localhost:54321
|
27 |
world_size: 1
|
|
|
|
|
|
|
|
25 |
dist_backend: nccl
|
26 |
dist_url: tcp://localhost:54321
|
27 |
world_size: 1
|
28 |
+
|
29 |
+
discriminator_dim: 32
|
30 |
+
discriminator_in_channel: 2
|