HoneyTian committed on
Commit
63dd56a
·
1 Parent(s): 4fbb8e0
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -328,7 +328,6 @@ def main():
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
  loss = speech_loss + irm_loss + snr_loss
331
- # loss = irm_loss + snr_loss
332
 
333
  total_loss += loss.item()
334
  total_examples += mix_complex_spec.size(0)
@@ -373,7 +372,6 @@ def main():
373
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
  loss = speech_loss + irm_loss + snr_loss
376
- # loss = irm_loss + snr_loss
377
 
378
  total_loss += loss.item()
379
  total_examples += mix_complex_spec.size(0)
 
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
  loss = speech_loss + irm_loss + snr_loss
 
331
 
332
  total_loss += loss.item()
333
  total_examples += mix_complex_spec.size(0)
 
372
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
373
 
374
  loss = speech_loss + irm_loss + snr_loss
 
375
 
376
  total_loss += loss.item()
377
  total_examples += mix_complex_spec.size(0)
examples/spectrum_dfnet_aishell/step_3_evaluation.py CHANGED
@@ -94,8 +94,12 @@ istft = torchaudio.transforms.InverseSpectrogram(
94
  )
95
 
96
 
97
- def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
 
 
 
98
  mix_spec_complex = mix_spec_complex.detach().cpu()
 
99
  speech_irm_prediction = speech_irm_prediction.detach().cpu()
100
 
101
  mask_speech = speech_irm_prediction
@@ -104,7 +108,8 @@ def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor)
104
  speech_spec = mix_spec_complex * mask_speech
105
  noise_spec = mix_spec_complex * mask_noise
106
 
107
- speech_wave = istft.forward(speech_spec)
 
108
  noise_wave = istft.forward(noise_spec)
109
 
110
  return speech_wave, noise_wave
@@ -212,6 +217,7 @@ def main():
212
  speech_spec = speech_spec[:, :-1, :]
213
  mix_spec = mix_spec[:, :-1, :]
214
 
 
215
  mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
216
  # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
217
 
@@ -221,6 +227,7 @@ def main():
221
  snr_db: torch.Tensor = 10 * torch.log10(
222
  speech_spec / (noise_spec + 1e-8)
223
  )
 
224
  snr_db = torch.mean(snr_db, dim=1, keepdim=True)
225
  # snr_db shape: [batch_size, 1, time_steps]
226
 
@@ -229,7 +236,7 @@ def main():
229
  snr_db_target = snr_db.to(device)
230
 
231
  with torch.no_grad():
232
- speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
233
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
234
  # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
235
  # loss = irm_loss + 0.1 * snr_loss
@@ -246,7 +253,7 @@ def main():
246
  dim=1,
247
  )
248
  # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
249
- speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
250
  save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
251
 
252
  total_loss += loss.item()
 
94
  )
95
 
96
 
97
+ def enhance(mix_spec_complex: torch.Tensor,
98
+ speech_spec_prediction: torch.Tensor,
99
+ speech_irm_prediction: torch.Tensor,
100
+ ):
101
  mix_spec_complex = mix_spec_complex.detach().cpu()
102
+ speech_spec_prediction = speech_spec_prediction.detach().cpu()
103
  speech_irm_prediction = speech_irm_prediction.detach().cpu()
104
 
105
  mask_speech = speech_irm_prediction
 
108
  speech_spec = mix_spec_complex * mask_speech
109
  noise_spec = mix_spec_complex * mask_noise
110
 
111
+ speech_wave = istft.forward(speech_spec_prediction)
112
+ # speech_wave = istft.forward(speech_spec)
113
  noise_wave = istft.forward(noise_spec)
114
 
115
  return speech_wave, noise_wave
 
217
  speech_spec = speech_spec[:, :-1, :]
218
  mix_spec = mix_spec[:, :-1, :]
219
 
220
+ speech_spec_complex: torch.Tensor = stft_complex.forward(speech_wave)
221
  mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
222
  # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
223
 
 
227
  snr_db: torch.Tensor = 10 * torch.log10(
228
  speech_spec / (noise_spec + 1e-8)
229
  )
230
+ snr_db = torch.clamp(snr_db, min=1e-8)
231
  snr_db = torch.mean(snr_db, dim=1, keepdim=True)
232
  # snr_db shape: [batch_size, 1, time_steps]
233
 
 
236
  snr_db_target = snr_db.to(device)
237
 
238
  with torch.no_grad():
239
+ speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_spec_complex)
240
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
241
  # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
242
  # loss = irm_loss + 0.1 * snr_loss
 
253
  dim=1,
254
  )
255
  # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
256
+ speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
257
  save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
258
 
259
  total_loss += loss.item()