Spaces:
Running
Running
update
Browse files
examples/spectrum_dfnet_aishell/step_2_train_model.py
CHANGED
@@ -328,7 +328,6 @@ def main():
|
|
328 |
# raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
loss = speech_loss + irm_loss + snr_loss
|
331 |
-
# loss = irm_loss + snr_loss
|
332 |
|
333 |
total_loss += loss.item()
|
334 |
total_examples += mix_complex_spec.size(0)
|
@@ -373,7 +372,6 @@ def main():
|
|
373 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
374 |
|
375 |
loss = speech_loss + irm_loss + snr_loss
|
376 |
-
# loss = irm_loss + snr_loss
|
377 |
|
378 |
total_loss += loss.item()
|
379 |
total_examples += mix_complex_spec.size(0)
|
|
|
328 |
# raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
loss = speech_loss + irm_loss + snr_loss
|
|
|
331 |
|
332 |
total_loss += loss.item()
|
333 |
total_examples += mix_complex_spec.size(0)
|
|
|
372 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
373 |
|
374 |
loss = speech_loss + irm_loss + snr_loss
|
|
|
375 |
|
376 |
total_loss += loss.item()
|
377 |
total_examples += mix_complex_spec.size(0)
|
examples/spectrum_dfnet_aishell/step_3_evaluation.py
CHANGED
@@ -94,8 +94,12 @@ istft = torchaudio.transforms.InverseSpectrogram(
|
|
94 |
)
|
95 |
|
96 |
|
97 |
-
def enhance(mix_spec_complex: torch.Tensor,
|
|
|
|
|
|
|
98 |
mix_spec_complex = mix_spec_complex.detach().cpu()
|
|
|
99 |
speech_irm_prediction = speech_irm_prediction.detach().cpu()
|
100 |
|
101 |
mask_speech = speech_irm_prediction
|
@@ -104,7 +108,8 @@ def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor)
|
|
104 |
speech_spec = mix_spec_complex * mask_speech
|
105 |
noise_spec = mix_spec_complex * mask_noise
|
106 |
|
107 |
-
speech_wave = istft.forward(
|
|
|
108 |
noise_wave = istft.forward(noise_spec)
|
109 |
|
110 |
return speech_wave, noise_wave
|
@@ -212,6 +217,7 @@ def main():
|
|
212 |
speech_spec = speech_spec[:, :-1, :]
|
213 |
mix_spec = mix_spec[:, :-1, :]
|
214 |
|
|
|
215 |
mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
|
216 |
# mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
|
217 |
|
@@ -221,6 +227,7 @@ def main():
|
|
221 |
snr_db: torch.Tensor = 10 * torch.log10(
|
222 |
speech_spec / (noise_spec + 1e-8)
|
223 |
)
|
|
|
224 |
snr_db = torch.mean(snr_db, dim=1, keepdim=True)
|
225 |
# snr_db shape: [batch_size, 1, time_steps]
|
226 |
|
@@ -229,7 +236,7 @@ def main():
|
|
229 |
snr_db_target = snr_db.to(device)
|
230 |
|
231 |
with torch.no_grad():
|
232 |
-
speech_irm_prediction, lsnr_prediction = model.forward(
|
233 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
234 |
# snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
235 |
# loss = irm_loss + 0.1 * snr_loss
|
@@ -246,7 +253,7 @@ def main():
|
|
246 |
dim=1,
|
247 |
)
|
248 |
# speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
|
249 |
-
speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
|
250 |
save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
|
251 |
|
252 |
total_loss += loss.item()
|
|
|
94 |
)
|
95 |
|
96 |
|
97 |
+
def enhance(mix_spec_complex: torch.Tensor,
|
98 |
+
speech_spec_prediction: torch.Tensor,
|
99 |
+
speech_irm_prediction: torch.Tensor,
|
100 |
+
):
|
101 |
mix_spec_complex = mix_spec_complex.detach().cpu()
|
102 |
+
speech_spec_prediction = speech_spec_prediction.detach().cpu()
|
103 |
speech_irm_prediction = speech_irm_prediction.detach().cpu()
|
104 |
|
105 |
mask_speech = speech_irm_prediction
|
|
|
108 |
speech_spec = mix_spec_complex * mask_speech
|
109 |
noise_spec = mix_spec_complex * mask_noise
|
110 |
|
111 |
+
speech_wave = istft.forward(speech_spec_prediction)
|
112 |
+
# speech_wave = istft.forward(speech_spec)
|
113 |
noise_wave = istft.forward(noise_spec)
|
114 |
|
115 |
return speech_wave, noise_wave
|
|
|
217 |
speech_spec = speech_spec[:, :-1, :]
|
218 |
mix_spec = mix_spec[:, :-1, :]
|
219 |
|
220 |
+
speech_spec_complex: torch.Tensor = stft_complex.forward(speech_wave)
|
221 |
mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
|
222 |
# mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
|
223 |
|
|
|
227 |
snr_db: torch.Tensor = 10 * torch.log10(
|
228 |
speech_spec / (noise_spec + 1e-8)
|
229 |
)
|
230 |
+
snr_db = torch.clamp(snr_db, min=1e-8)
|
231 |
snr_db = torch.mean(snr_db, dim=1, keepdim=True)
|
232 |
# snr_db shape: [batch_size, 1, time_steps]
|
233 |
|
|
|
236 |
snr_db_target = snr_db.to(device)
|
237 |
|
238 |
with torch.no_grad():
|
239 |
+
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_spec_complex)
|
240 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
241 |
# snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
242 |
# loss = irm_loss + 0.1 * snr_loss
|
|
|
253 |
dim=1,
|
254 |
)
|
255 |
# speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
|
256 |
+
speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
|
257 |
save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
|
258 |
|
259 |
total_loss += loss.item()
|