HoneyTian commited on
Commit
ce96198
·
1 Parent(s): c6c50f4
examples/nx_mpnet/run.sh CHANGED
@@ -3,10 +3,11 @@
3
  : <<'END'
4
 
5
 
6
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-mpnet-aishell-20250224 \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
9
- --max_epochs 100
 
10
 
11
 
12
  END
@@ -26,6 +27,7 @@ limit=10
26
 
27
  noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
28
  speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
 
29
 
30
  nohup_name=nohup.out
31
 
@@ -93,6 +95,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
93
  --speech_dir "${speech_dir}" \
94
  --train_dataset "${train_dataset}" \
95
  --valid_dataset "${valid_dataset}" \
 
96
 
97
  fi
98
 
 
3
  : <<'END'
4
 
5
 
6
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-mpnet-aishell-20250224 \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
9
+ --max_epochs 100 \
10
+ --duration 2 \
11
 
12
 
13
  END
 
27
 
28
  noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
29
  speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
30
+ duration=2
31
 
32
  nohup_name=nohup.out
33
 
 
95
  --speech_dir "${speech_dir}" \
96
  --train_dataset "${train_dataset}" \
97
  --valid_dataset "${valid_dataset}" \
98
+ --duration "${duration}" \
99
 
100
  fi
101
 
examples/nx_mpnet/yaml/config.yaml CHANGED
@@ -15,9 +15,9 @@ mask_hidden_size: 64
15
  phase_num_blocks: 4
16
  phase_hidden_size: 64
17
 
18
- tsfm_hidden_size: 64
19
- tsfm_attention_heads: 4
20
- tsfm_num_blocks: 4
21
  tsfm_dropout_rate: 0.0
22
  tsfm_max_time_relative_position: 2048
23
  tsfm_max_freq_relative_position: 256
 
15
  phase_num_blocks: 4
16
  phase_hidden_size: 64
17
 
18
+ tsfm_hidden_size: 128
19
+ tsfm_attention_heads: 8
20
+ tsfm_num_blocks: 6
21
  tsfm_dropout_rate: 0.0
22
  tsfm_max_time_relative_position: 2048
23
  tsfm_max_freq_relative_position: 256
toolbox/torchaudio/models/nx_mpnet/inference_mpnet.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from pathlib import Path
5
+ import shutil
6
+ import tempfile
7
+ import zipfile
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+
14
+ from project_settings import project_path
15
+ from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
16
+ from toolbox.torchaudio.models.nx_mpnet.modeling_nx_mpnet import NXMPNetPretrainedModel, MODEL_FILE
17
+ from toolbox.torchaudio.models.nx_mpnet.utils import mag_pha_stft, mag_pha_istft
18
+
19
+ logger = logging.getLogger("toolbox")
20
+
21
+
22
+ class InferenceNXMPNet(object):
23
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
24
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
25
+ self.device = torch.device(device)
26
+
27
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
28
+ config, generator = self.load_models(self.pretrained_model_path_or_zip_file)
29
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
30
+
31
+ self.config = config
32
+ self.generator = generator
33
+ self.generator.to(device)
34
+ self.generator.eval()
35
+
36
+ def load_models(self, model_path: str):
37
+ model_path = Path(model_path)
38
+ if model_path.name.endswith(".zip"):
39
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
40
+ out_root = Path(tempfile.gettempdir()) / "nx_denoise"
41
+ out_root.mkdir(parents=True, exist_ok=True)
42
+ f_zip.extractall(path=out_root)
43
+ model_path = out_root / model_path.stem
44
+
45
+ config = NXMPNetConfig.from_pretrained(
46
+ pretrained_model_name_or_path=model_path.as_posix(),
47
+ )
48
+ generator = NXMPNetPretrainedModel.from_pretrained(
49
+ pretrained_model_name_or_path=model_path.as_posix(),
50
+ )
51
+ generator.to(self.device)
52
+ generator.eval()
53
+
54
+ shutil.rmtree(model_path)
55
+ return config, generator
56
+
57
+ def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
58
+ if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
59
+ raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
60
+
61
+ noisy_audio = noisy_audio.to(self.device)
62
+
63
+ with torch.no_grad():
64
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
65
+ noisy_audio,
66
+ self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
67
+ )
68
+ # mag_g, pha_g, com_g = self.generator.forward(noisy_mag, noisy_pha)
69
+ mag_g, pha_g, com_g = self.generator.forward_chunk_by_chunk(noisy_mag, noisy_pha)
70
+ audio_g = mag_pha_istft(
71
+ mag_g, pha_g,
72
+ self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
73
+ )
74
+ enhanced_audio = audio_g.detach()
75
+
76
+ enhanced_audio = enhanced_audio[0]
77
+ return enhanced_audio
78
+
79
+
80
+ def main():
81
+ model_zip_file = project_path / "trained_models/mpnet-aishell-1-epoch.zip"
82
+ infer_mpnet = InferenceNXMPNet(model_zip_file)
83
+
84
+ sample_rate = 8000
85
+ noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
86
+ noisy_audio, _ = librosa.load(
87
+ noisy_audio_file.as_posix(),
88
+ sr=sample_rate,
89
+ )
90
+ noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
91
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
92
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
93
+
94
+ enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio)
95
+
96
+ filename = "enhanced_audio.wav"
97
+ torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
98
+
99
+ return
100
+
101
+
102
+ if __name__ == '__main__':
103
+ main()
toolbox/torchaudio/models/nx_mpnet/modeling_nx_mpnet.py CHANGED
@@ -18,6 +18,8 @@ class NXMPNet(nn.Module):
18
  config: NXMPNetConfig,
19
  ):
20
  super(NXMPNet, self).__init__()
 
 
21
  self.dense_encoder = DenseEncoder(
22
  num_blocks=config.dense_num_blocks,
23
  in_channels=2,
@@ -73,6 +75,91 @@ class NXMPNet(nn.Module):
73
 
74
  return denoised_amp, denoised_pha, denoised_com
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  MODEL_FILE = "generator.pt"
78
 
@@ -136,6 +223,11 @@ def main():
136
  print(denoised_amp.shape)
137
  print(denoised_pha.shape)
138
  print(denoised_com.shape)
 
 
 
 
 
139
  return
140
 
141
 
 
18
  config: NXMPNetConfig,
19
  ):
20
  super(NXMPNet, self).__init__()
21
+ self.config = config
22
+
23
  self.dense_encoder = DenseEncoder(
24
  num_blocks=config.dense_num_blocks,
25
  in_channels=2,
 
75
 
76
  return denoised_amp, denoised_pha, denoised_com
77
 
78
+ def forward_chunk(self,
79
+ chunk_noisy_amp: torch.Tensor,
80
+ chunk_noisy_pha: torch.Tensor,
81
+ cache: dict,
82
+ ):
83
+ dense_encoder_cache_pad_list = cache["dense_encoder_cache_pad_list"]
84
+ mask_decoder_cache_pad_list = cache["mask_decoder_cache_pad_list"]
85
+ phase_decoder_cache_pad_list = cache["phase_decoder_cache_pad_list"]
86
+ ts_transformer_cache_att_list = cache["ts_transformer_cache_att_list"]
87
+ max_att_cache_length = cache["max_att_cache_length"]
88
+
89
+ x = torch.stack((chunk_noisy_amp, chunk_noisy_pha), dim=-1).permute(0, 3, 2, 1) # [B, 2, T, F]
90
+ # x shape: [b, 2, t, f]
91
+ x, new_dense_encoder_cache_pad_list = self.dense_encoder.forward_chunk(x, cache_pad_list=dense_encoder_cache_pad_list)
92
+ # x shape: [b, c, t, f//2]
93
+
94
+ x, new_ts_transformer_cache_att_list = self.ts_transformer.forward_chunk(
95
+ x,
96
+ max_att_cache_length=max_att_cache_length,
97
+ cache_att_list=ts_transformer_cache_att_list
98
+ )
99
+ # x shape: [b, c, t, f//2]
100
+
101
+ mask, new_mask_decoder_cache_pad_list = self.mask_decoder.forward_chunk(x, cache_pad_list=mask_decoder_cache_pad_list)
102
+ denoised_amp = chunk_noisy_amp * mask
103
+ denoised_pha, new_phase_decoder_cache_pad_list = self.phase_decoder.forward_chunk(x, cache_pad_list=phase_decoder_cache_pad_list)
104
+ denoised_com = torch.stack(
105
+ tensors=(
106
+ denoised_amp * torch.cos(denoised_pha),
107
+ denoised_amp * torch.sin(denoised_pha)
108
+ ),
109
+ dim=-1
110
+ )
111
+
112
+ cache = {
113
+ "dense_encoder_cache_pad_list": new_dense_encoder_cache_pad_list,
114
+ "mask_decoder_cache_pad_list": new_mask_decoder_cache_pad_list,
115
+ "phase_decoder_cache_pad_list": new_phase_decoder_cache_pad_list,
116
+ "ts_transformer_cache_att_list": new_ts_transformer_cache_att_list,
117
+ "max_att_cache_length": max_att_cache_length,
118
+
119
+ }
120
+
121
+ return denoised_amp, denoised_pha, denoised_com, cache
122
+
123
+ def forward_chunk_by_chunk(self,
124
+ noisy_amp: torch.Tensor,
125
+ noisy_pha: torch.Tensor,
126
+ ):
127
+ """
128
+ :param noisy_amp: Tensor, shape: [b, f, t]
129
+ :param noisy_pha: Tensor, shape: [b, f, t]
130
+ :return:
131
+ """
132
+ b, f, t = noisy_amp.shape
133
+
134
+ max_att_cache_length = (self.config.tsfm_num_left_chunks + self.config.tsfm_num_right_chunks) * self.config.tsfm_chunk_size
135
+
136
+ cache = {
137
+ "dense_encoder_cache_pad_list": None,
138
+ "mask_decoder_cache_pad_list": None,
139
+ "phase_decoder_cache_pad_list": None,
140
+ "ts_transformer_cache_att_list": None,
141
+ "max_att_cache_length": max_att_cache_length,
142
+
143
+ }
144
+
145
+ denoised_amp_list = list()
146
+ denoised_pha_list = list()
147
+ denoised_com_list = list()
148
+
149
+ for idx in range(t):
150
+ chunk_noisy_amp = noisy_amp[:, :, idx:idx+1]
151
+ chunk_noisy_pha = noisy_pha[:, :, idx:idx+1]
152
+
153
+ denoised_amp, denoised_pha, denoised_com, cache = self.forward_chunk(chunk_noisy_amp, chunk_noisy_pha, cache)
154
+ denoised_amp_list.append(denoised_amp)
155
+ denoised_pha_list.append(denoised_pha)
156
+ denoised_com_list.append(denoised_com)
157
+
158
+ denoised_amp_list = torch.concat(denoised_amp_list, dim=2)
159
+ denoised_pha_list = torch.concat(denoised_pha_list, dim=2)
160
+ denoised_com_list = torch.concat(denoised_com_list, dim=2)
161
+ return denoised_amp_list, denoised_pha_list, denoised_com_list
162
+
163
 
164
  MODEL_FILE = "generator.pt"
165
 
 
223
  print(denoised_amp.shape)
224
  print(denoised_pha.shape)
225
  print(denoised_com.shape)
226
+
227
+ denoised_amp, denoised_pha, denoised_com = model.forward_chunk_by_chunk(noisy_amp, noisy_pha)
228
+ print(denoised_amp.shape)
229
+ print(denoised_pha.shape)
230
+ print(denoised_com.shape)
231
  return
232
 
233
 
toolbox/torchaudio/models/nx_mpnet/transformers/transformers.py CHANGED
@@ -361,13 +361,13 @@ class TSTransformerEncoder(nn.Module):
361
  def forward_chunk(self,
362
  xs: torch.Tensor,
363
  max_att_cache_length: int,
364
- attention_cache: torch.Tensor = None,
365
  ) -> Tuple[torch.Tensor, torch.Tensor]:
366
  """
367
 
368
  :param xs:
369
  :param max_att_cache_length:
370
- :param attention_cache: Tensor, shape: [num_layers, ...]
371
  :return:
372
  """
373
  # xs shape: [batch_size, channels, time_steps, input_size]
@@ -376,19 +376,17 @@ class TSTransformerEncoder(nn.Module):
376
  xs = xs.permute(0, 3, 2, 1)
377
  # xs shape: [batch_size, hidden_size, time_steps, input_size]
378
 
379
- r_att_cache = []
380
  for idx, encoder_layer in enumerate(self.encoder_layer_list):
381
- xs, new_att_cache = encoder_layer.forward(
382
- x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
383
  )
384
  # new_att_cache shape: [b*f, n_heads, time_steps, dim]
385
- if new_att_cache.size(2) > max_att_cache_length:
386
  begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
387
  end = self.num_right_chunks * self.chunk_size
388
- new_att_cache = new_att_cache[:, :, -begin:-end, :]
389
- r_att_cache.append(new_att_cache)
390
-
391
- r_att_cache = torch.stack(r_att_cache, dim=0)
392
 
393
  # xs shape: [batch_size, hidden_size, time_steps, input_size]
394
  xs = xs.permute(0, 3, 2, 1)
@@ -396,7 +394,7 @@ class TSTransformerEncoder(nn.Module):
396
  xs = xs.permute(0, 3, 2, 1)
397
  # xs shape: [batch_size, channels, time_steps, input_size]
398
 
399
- return xs, r_att_cache
400
 
401
  def forward_chunk_by_chunk(
402
  self,
@@ -406,7 +404,7 @@ class TSTransformerEncoder(nn.Module):
406
  batch_size, channels, time_steps, _ = xs.shape
407
 
408
  max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
409
- attention_cache = None
410
 
411
  outputs = []
412
  for idx in range(0, time_steps, self.chunk_size):
@@ -415,10 +413,10 @@ class TSTransformerEncoder(nn.Module):
415
  chunk_xs = xs[:, :, begin:end, :]
416
  # chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
417
 
418
- ys, attention_cache = self.forward_chunk(
419
  xs=chunk_xs,
420
  max_att_cache_length=max_att_cache_length,
421
- attention_cache=attention_cache,
422
  )
423
  # ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
424
  ys = ys[:, :, :self.chunk_size, :]
 
361
  def forward_chunk(self,
362
  xs: torch.Tensor,
363
  max_att_cache_length: int,
364
+ cache_att_list: List[torch.Tensor] = None,
365
  ) -> Tuple[torch.Tensor, torch.Tensor]:
366
  """
367
 
368
  :param xs:
369
  :param max_att_cache_length:
370
+ :param cache_att_list: Tensor, shape: [num_layers, ...]
371
  :return:
372
  """
373
  # xs shape: [batch_size, channels, time_steps, input_size]
 
376
  xs = xs.permute(0, 3, 2, 1)
377
  # xs shape: [batch_size, hidden_size, time_steps, input_size]
378
 
379
+ new_cache_att_list = list()
380
  for idx, encoder_layer in enumerate(self.encoder_layer_list):
381
+ xs, new_cache_att = encoder_layer.forward(
382
+ x=xs, attention_cache=cache_att_list[idx] if cache_att_list is not None else None,
383
  )
384
  # new_att_cache shape: [b*f, n_heads, time_steps, dim]
385
+ if new_cache_att.size(2) > max_att_cache_length:
386
  begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
387
  end = self.num_right_chunks * self.chunk_size
388
+ new_cache_att = new_cache_att[:, :, -begin:-end, :]
389
+ new_cache_att_list.append(new_cache_att)
 
 
390
 
391
  # xs shape: [batch_size, hidden_size, time_steps, input_size]
392
  xs = xs.permute(0, 3, 2, 1)
 
394
  xs = xs.permute(0, 3, 2, 1)
395
  # xs shape: [batch_size, channels, time_steps, input_size]
396
 
397
+ return xs, new_cache_att_list
398
 
399
  def forward_chunk_by_chunk(
400
  self,
 
404
  batch_size, channels, time_steps, _ = xs.shape
405
 
406
  max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
407
+ cache_att_list = None
408
 
409
  outputs = []
410
  for idx in range(0, time_steps, self.chunk_size):
 
413
  chunk_xs = xs[:, :, begin:end, :]
414
  # chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
415
 
416
+ ys, cache_att_list = self.forward_chunk(
417
  xs=chunk_xs,
418
  max_att_cache_length=max_att_cache_length,
419
+ cache_att_list=cache_att_list,
420
  )
421
  # ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
422
  ys = ys[:, :, :self.chunk_size, :]