Spaces:
Running
Running
update
Browse files
examples/nx_mpnet/run.sh
CHANGED
@@ -3,10 +3,11 @@
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
-
sh run.sh --stage
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
|
9 |
-
--max_epochs 100
|
|
|
10 |
|
11 |
|
12 |
END
|
@@ -26,6 +27,7 @@ limit=10
|
|
26 |
|
27 |
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
28 |
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
|
|
29 |
|
30 |
nohup_name=nohup.out
|
31 |
|
@@ -93,6 +95,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
|
93 |
--speech_dir "${speech_dir}" \
|
94 |
--train_dataset "${train_dataset}" \
|
95 |
--valid_dataset "${valid_dataset}" \
|
|
|
96 |
|
97 |
fi
|
98 |
|
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
+
sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-mpnet-aishell-20250224 \
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
|
9 |
+
--max_epochs 100 \
|
10 |
+
--duration 2 \
|
11 |
|
12 |
|
13 |
END
|
|
|
27 |
|
28 |
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
29 |
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
30 |
+
duration=2
|
31 |
|
32 |
nohup_name=nohup.out
|
33 |
|
|
|
95 |
--speech_dir "${speech_dir}" \
|
96 |
--train_dataset "${train_dataset}" \
|
97 |
--valid_dataset "${valid_dataset}" \
|
98 |
+
--duration "${duration}" \
|
99 |
|
100 |
fi
|
101 |
|
examples/nx_mpnet/yaml/config.yaml
CHANGED
@@ -15,9 +15,9 @@ mask_hidden_size: 64
|
|
15 |
phase_num_blocks: 4
|
16 |
phase_hidden_size: 64
|
17 |
|
18 |
-
tsfm_hidden_size:
|
19 |
-
tsfm_attention_heads:
|
20 |
-
tsfm_num_blocks:
|
21 |
tsfm_dropout_rate: 0.0
|
22 |
tsfm_max_time_relative_position: 2048
|
23 |
tsfm_max_freq_relative_position: 256
|
|
|
15 |
phase_num_blocks: 4
|
16 |
phase_hidden_size: 64
|
17 |
|
18 |
+
tsfm_hidden_size: 128
|
19 |
+
tsfm_attention_heads: 8
|
20 |
+
tsfm_num_blocks: 6
|
21 |
tsfm_dropout_rate: 0.0
|
22 |
tsfm_max_time_relative_position: 2048
|
23 |
tsfm_max_freq_relative_position: 256
|
toolbox/torchaudio/models/nx_mpnet/inference_mpnet.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import logging
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
import tempfile
|
7 |
+
import zipfile
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torchaudio
|
13 |
+
|
14 |
+
from project_settings import project_path
|
15 |
+
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
|
16 |
+
from toolbox.torchaudio.models.nx_mpnet.modeling_nx_mpnet import NXMPNetPretrainedModel, MODEL_FILE
|
17 |
+
from toolbox.torchaudio.models.nx_mpnet.utils import mag_pha_stft, mag_pha_istft
|
18 |
+
|
19 |
+
logger = logging.getLogger("toolbox")
|
20 |
+
|
21 |
+
|
22 |
+
class InferenceNXMPNet(object):
    """Load a pretrained NX-MPNet generator and use it to enhance noisy audio.

    The model may be given either as a directory accepted by
    ``from_pretrained`` or as a ``.zip`` archive of such a directory.
    """

    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        """
        :param pretrained_model_path_or_zip_file: model directory, or a ``.zip`` archive
            whose top-level folder is named after the archive stem.
        :param device: torch device specifier (e.g. ``"cpu"``) the generator is moved to.
        """
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, generator = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.generator = generator
        # load_models already placed the generator on self.device in eval mode;
        # repeating it here is harmless and keeps the invariant local to __init__.
        self.generator.to(self.device)
        self.generator.eval()

    def load_models(self, model_path: str):
        """
        Load the configuration and the generator weights.

        A ``.zip`` archive is unpacked into a private temporary directory that is
        removed again once loading has finished (or failed). A plain directory is
        read in place and is never modified or deleted.

        :param model_path: model directory or ``.zip`` archive.
        :return: tuple ``(config, generator)``; the generator is on ``self.device`` in eval mode.
        """
        model_path = Path(model_path)
        if not model_path.name.endswith(".zip"):
            return self._load_from_directory(model_path)

        # A unique directory per call avoids collisions between concurrent loads,
        # and the context manager cleans up even when from_pretrained raises.
        with tempfile.TemporaryDirectory(prefix="nx_denoise_") as out_root:
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                f_zip.extractall(path=out_root)
            # The archive is expected to contain one folder named after its stem.
            return self._load_from_directory(Path(out_root) / model_path.stem)

    def _load_from_directory(self, model_dir: Path):
        """Build the config and generator from an unpacked model directory."""
        config = NXMPNetConfig.from_pretrained(
            pretrained_model_name_or_path=model_dir.as_posix(),
        )
        generator = NXMPNetPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_dir.as_posix(),
        )
        generator.to(self.device)
        generator.eval()
        return config, generator

    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        """
        Enhance a noisy waveform with the streaming (chunk-by-chunk) forward pass.

        :param noisy_audio: waveform tensor with samples in [-1, 1].
            Presumably shaped [batch, num_samples] (that is what ``main`` passes) - TODO confirm.
        :return: the first item along dim 0 of the enhanced output.
        :raises AssertionError: if any sample lies outside [-1, 1].
        """
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")

        noisy_audio = noisy_audio.to(self.device)

        with torch.no_grad():
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
                noisy_audio,
                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
            )
            # Streaming path; the non-streaming equivalent is self.generator.forward(noisy_mag, noisy_pha).
            mag_g, pha_g, com_g = self.generator.forward_chunk_by_chunk(noisy_mag, noisy_pha)
            audio_g = mag_pha_istft(
                mag_g, pha_g,
                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
            )
            enhanced_audio = audio_g.detach()

        enhanced_audio = enhanced_audio[0]
        return enhanced_audio
|
78 |
+
|
79 |
+
|
80 |
+
def main():
    """Denoise a two-second excerpt of a bundled example recording and write it to a WAV file."""
    enhancer = InferenceNXMPNet(project_path / "trained_models/mpnet-aishell-1-epoch.zip")

    sample_rate = 8000
    wav_path = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
    waveform, _ = librosa.load(wav_path.as_posix(), sr=sample_rate)

    # Keep only the samples between second 7 and second 9, then add a leading dim of size 1.
    first_sample = int(7*sample_rate)
    last_sample = int(9*sample_rate)
    batch = torch.tensor(waveform[first_sample:last_sample], dtype=torch.float32).unsqueeze(dim=0)

    denoised = enhancer.enhancement_by_tensor(batch)

    # NOTE(review): enhancement_by_tensor returns output[0]; confirm the result is the
    # [channels, time] layout that torchaudio.save expects.
    torchaudio.save("enhanced_audio.wav", denoised.detach().cpu(), sample_rate)

    return


if __name__ == '__main__':
    main()
|
toolbox/torchaudio/models/nx_mpnet/modeling_nx_mpnet.py
CHANGED
@@ -18,6 +18,8 @@ class NXMPNet(nn.Module):
|
|
18 |
config: NXMPNetConfig,
|
19 |
):
|
20 |
super(NXMPNet, self).__init__()
|
|
|
|
|
21 |
self.dense_encoder = DenseEncoder(
|
22 |
num_blocks=config.dense_num_blocks,
|
23 |
in_channels=2,
|
@@ -73,6 +75,91 @@ class NXMPNet(nn.Module):
|
|
73 |
|
74 |
return denoised_amp, denoised_pha, denoised_com
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
MODEL_FILE = "generator.pt"
|
78 |
|
@@ -136,6 +223,11 @@ def main():
|
|
136 |
print(denoised_amp.shape)
|
137 |
print(denoised_pha.shape)
|
138 |
print(denoised_com.shape)
|
|
|
|
|
|
|
|
|
|
|
139 |
return
|
140 |
|
141 |
|
|
|
18 |
config: NXMPNetConfig,
|
19 |
):
|
20 |
super(NXMPNet, self).__init__()
|
21 |
+
self.config = config
|
22 |
+
|
23 |
self.dense_encoder = DenseEncoder(
|
24 |
num_blocks=config.dense_num_blocks,
|
25 |
in_channels=2,
|
|
|
75 |
|
76 |
return denoised_amp, denoised_pha, denoised_com
|
77 |
|
78 |
+
def forward_chunk(self,
                  chunk_noisy_amp: torch.Tensor,
                  chunk_noisy_pha: torch.Tensor,
                  cache: dict,
                  ):
    """
    Run one streaming step: encoder -> transformer -> mask / phase decoders.

    :param chunk_noisy_amp: Tensor, noisy magnitude chunk; stacked with the phase to form [b, 2, t, f].
    :param chunk_noisy_pha: Tensor, noisy phase chunk, same shape as ``chunk_noisy_amp``.
    :param cache: dict holding the per-module streaming state under the keys
        ``dense_encoder_cache_pad_list``, ``mask_decoder_cache_pad_list``,
        ``phase_decoder_cache_pad_list``, ``ts_transformer_cache_att_list``
        and ``max_att_cache_length``.
    :return: tuple (denoised_amp, denoised_pha, denoised_com, cache) where ``cache``
        is a new dict carrying the updated state for the next call.
    """
    encoder_pads, mask_pads, phase_pads, attention_caches, att_cache_limit = (
        cache["dense_encoder_cache_pad_list"],
        cache["mask_decoder_cache_pad_list"],
        cache["phase_decoder_cache_pad_list"],
        cache["ts_transformer_cache_att_list"],
        cache["max_att_cache_length"],
    )

    # Pair magnitude and phase on a trailing axis, then reorder to [b, 2, t, f].
    features = torch.stack((chunk_noisy_amp, chunk_noisy_pha), dim=-1)
    features = features.permute(0, 3, 2, 1)

    # features shape after the encoder: [b, c, t, f//2]
    features, next_encoder_pads = self.dense_encoder.forward_chunk(
        features, cache_pad_list=encoder_pads
    )

    # shape unchanged by the transformer: [b, c, t, f//2]
    features, next_attention_caches = self.ts_transformer.forward_chunk(
        features,
        max_att_cache_length=att_cache_limit,
        cache_att_list=attention_caches
    )

    # The magnitude is denoised by a multiplicative mask; the phase is predicted directly.
    mask, next_mask_pads = self.mask_decoder.forward_chunk(features, cache_pad_list=mask_pads)
    denoised_amp = chunk_noisy_amp * mask
    denoised_pha, next_phase_pads = self.phase_decoder.forward_chunk(features, cache_pad_list=phase_pads)

    # Complex spectrum as (real, imag) on the last axis.
    real_part = denoised_amp * torch.cos(denoised_pha)
    imag_part = denoised_amp * torch.sin(denoised_pha)
    denoised_com = torch.stack((real_part, imag_part), dim=-1)

    next_cache = dict()
    next_cache["dense_encoder_cache_pad_list"] = next_encoder_pads
    next_cache["mask_decoder_cache_pad_list"] = next_mask_pads
    next_cache["phase_decoder_cache_pad_list"] = next_phase_pads
    next_cache["ts_transformer_cache_att_list"] = next_attention_caches
    next_cache["max_att_cache_length"] = att_cache_limit

    return denoised_amp, denoised_pha, denoised_com, next_cache
|
122 |
+
|
123 |
+
def forward_chunk_by_chunk(self,
                           noisy_amp: torch.Tensor,
                           noisy_pha: torch.Tensor,
                           ):
    """
    Simulate streaming inference by feeding the spectrogram one frame at a time.

    Each time index is passed to ``forward_chunk`` together with the cache
    returned by the previous step; the per-frame outputs are then joined
    along dim 2.

    :param noisy_amp: Tensor, shape: [b, f, t]
    :param noisy_pha: Tensor, shape: [b, f, t]
    :return: tuple (denoised_amp, denoised_pha, denoised_com), each concatenated over dim 2.
    """
    _, _, num_frames = noisy_amp.shape

    cfg = self.config
    context_chunks = cfg.tsfm_num_left_chunks + cfg.tsfm_num_right_chunks
    max_att_cache_length = context_chunks * cfg.tsfm_chunk_size

    # Every module cache starts empty; forward_chunk hands back the updated state.
    state = {
        "dense_encoder_cache_pad_list": None,
        "mask_decoder_cache_pad_list": None,
        "phase_decoder_cache_pad_list": None,
        "ts_transformer_cache_att_list": None,
        "max_att_cache_length": max_att_cache_length,
    }

    amp_frames, pha_frames, com_frames = [], [], []
    for frame in range(num_frames):
        # Slice with a range so the time axis is kept (length 1).
        window = slice(frame, frame + 1)
        amp_out, pha_out, com_out, state = self.forward_chunk(
            noisy_amp[:, :, window], noisy_pha[:, :, window], state
        )
        amp_frames.append(amp_out)
        pha_frames.append(pha_out)
        com_frames.append(com_out)

    denoised_amp = torch.concat(amp_frames, dim=2)
    denoised_pha = torch.concat(pha_frames, dim=2)
    denoised_com = torch.concat(com_frames, dim=2)
    return denoised_amp, denoised_pha, denoised_com
|
162 |
+
|
163 |
|
164 |
MODEL_FILE = "generator.pt"
|
165 |
|
|
|
223 |
print(denoised_amp.shape)
|
224 |
print(denoised_pha.shape)
|
225 |
print(denoised_com.shape)
|
226 |
+
|
227 |
+
denoised_amp, denoised_pha, denoised_com = model.forward_chunk_by_chunk(noisy_amp, noisy_pha)
|
228 |
+
print(denoised_amp.shape)
|
229 |
+
print(denoised_pha.shape)
|
230 |
+
print(denoised_com.shape)
|
231 |
return
|
232 |
|
233 |
|
toolbox/torchaudio/models/nx_mpnet/transformers/transformers.py
CHANGED
@@ -361,13 +361,13 @@ class TSTransformerEncoder(nn.Module):
|
|
361 |
def forward_chunk(self,
|
362 |
xs: torch.Tensor,
|
363 |
max_att_cache_length: int,
|
364 |
-
|
365 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
366 |
"""
|
367 |
|
368 |
:param xs:
|
369 |
:param max_att_cache_length:
|
370 |
-
:param
|
371 |
:return:
|
372 |
"""
|
373 |
# xs shape: [batch_size, channels, time_steps, input_size]
|
@@ -376,19 +376,17 @@ class TSTransformerEncoder(nn.Module):
|
|
376 |
xs = xs.permute(0, 3, 2, 1)
|
377 |
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
378 |
|
379 |
-
|
380 |
for idx, encoder_layer in enumerate(self.encoder_layer_list):
|
381 |
-
xs,
|
382 |
-
x=xs, attention_cache=
|
383 |
)
|
384 |
# new_att_cache shape: [b*f, n_heads, time_steps, dim]
|
385 |
-
if
|
386 |
begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
387 |
end = self.num_right_chunks * self.chunk_size
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
r_att_cache = torch.stack(r_att_cache, dim=0)
|
392 |
|
393 |
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
394 |
xs = xs.permute(0, 3, 2, 1)
|
@@ -396,7 +394,7 @@ class TSTransformerEncoder(nn.Module):
|
|
396 |
xs = xs.permute(0, 3, 2, 1)
|
397 |
# xs shape: [batch_size, channels, time_steps, input_size]
|
398 |
|
399 |
-
return xs,
|
400 |
|
401 |
def forward_chunk_by_chunk(
|
402 |
self,
|
@@ -406,7 +404,7 @@ class TSTransformerEncoder(nn.Module):
|
|
406 |
batch_size, channels, time_steps, _ = xs.shape
|
407 |
|
408 |
max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
409 |
-
|
410 |
|
411 |
outputs = []
|
412 |
for idx in range(0, time_steps, self.chunk_size):
|
@@ -415,10 +413,10 @@ class TSTransformerEncoder(nn.Module):
|
|
415 |
chunk_xs = xs[:, :, begin:end, :]
|
416 |
# chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
|
417 |
|
418 |
-
ys,
|
419 |
xs=chunk_xs,
|
420 |
max_att_cache_length=max_att_cache_length,
|
421 |
-
|
422 |
)
|
423 |
# ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
|
424 |
ys = ys[:, :, :self.chunk_size, :]
|
|
|
361 |
def forward_chunk(self,
|
362 |
xs: torch.Tensor,
|
363 |
max_att_cache_length: int,
|
364 |
+
cache_att_list: List[torch.Tensor] = None,
|
365 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
366 |
"""
|
367 |
|
368 |
:param xs:
|
369 |
:param max_att_cache_length:
|
370 |
+
:param cache_att_list: Tensor, shape: [num_layers, ...]
|
371 |
:return:
|
372 |
"""
|
373 |
# xs shape: [batch_size, channels, time_steps, input_size]
|
|
|
376 |
xs = xs.permute(0, 3, 2, 1)
|
377 |
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
378 |
|
379 |
+
new_cache_att_list = list()
|
380 |
for idx, encoder_layer in enumerate(self.encoder_layer_list):
|
381 |
+
xs, new_cache_att = encoder_layer.forward(
|
382 |
+
x=xs, attention_cache=cache_att_list[idx] if cache_att_list is not None else None,
|
383 |
)
|
384 |
# new_att_cache shape: [b*f, n_heads, time_steps, dim]
|
385 |
+
if new_cache_att.size(2) > max_att_cache_length:
|
386 |
begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
387 |
end = self.num_right_chunks * self.chunk_size
|
388 |
+
new_cache_att = new_cache_att[:, :, -begin:-end, :]
|
389 |
+
new_cache_att_list.append(new_cache_att)
|
|
|
|
|
390 |
|
391 |
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
392 |
xs = xs.permute(0, 3, 2, 1)
|
|
|
394 |
xs = xs.permute(0, 3, 2, 1)
|
395 |
# xs shape: [batch_size, channels, time_steps, input_size]
|
396 |
|
397 |
+
return xs, new_cache_att_list
|
398 |
|
399 |
def forward_chunk_by_chunk(
|
400 |
self,
|
|
|
404 |
batch_size, channels, time_steps, _ = xs.shape
|
405 |
|
406 |
max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
407 |
+
cache_att_list = None
|
408 |
|
409 |
outputs = []
|
410 |
for idx in range(0, time_steps, self.chunk_size):
|
|
|
413 |
chunk_xs = xs[:, :, begin:end, :]
|
414 |
# chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
|
415 |
|
416 |
+
ys, cache_att_list = self.forward_chunk(
|
417 |
xs=chunk_xs,
|
418 |
max_att_cache_length=max_att_cache_length,
|
419 |
+
cache_att_list=cache_att_list,
|
420 |
)
|
421 |
# ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
|
422 |
ys = ys[:, :, :self.chunk_size, :]
|