Spaces:
Running
Running
update
Browse files
examples/spectrum_dfnet_aishell/step_2_train_model.py
CHANGED
@@ -313,22 +313,21 @@ def main():
|
|
313 |
snr_db_target = snr_db.to(device)
|
314 |
|
315 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
316 |
-
|
317 |
-
|
318 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
319 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
320 |
-
|
321 |
-
|
322 |
|
323 |
-
|
324 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
325 |
-
|
326 |
|
327 |
# if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
328 |
# raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
-
|
331 |
-
loss = irm_loss
|
332 |
|
333 |
total_loss += loss.item()
|
334 |
total_examples += mix_complex_spec.size(0)
|
@@ -361,19 +360,18 @@ def main():
|
|
361 |
|
362 |
with torch.no_grad():
|
363 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
364 |
-
|
365 |
-
|
366 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
367 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
368 |
-
|
369 |
-
|
370 |
|
371 |
-
|
372 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
373 |
-
|
374 |
|
375 |
-
|
376 |
-
loss = irm_loss
|
377 |
|
378 |
total_loss += loss.item()
|
379 |
total_examples += mix_complex_spec.size(0)
|
|
|
313 |
snr_db_target = snr_db.to(device)
|
314 |
|
315 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
316 |
+
if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
|
317 |
+
raise AssertionError("nan or inf in speech_spec_prediction")
|
318 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
319 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
320 |
+
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
321 |
+
raise AssertionError("nan or inf in lsnr_prediction")
|
322 |
|
323 |
+
speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
|
324 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
325 |
+
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
326 |
|
327 |
# if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
328 |
# raise AssertionError("nan or inf in snr_loss")
|
329 |
|
330 |
+
loss = speech_loss + irm_loss + snr_loss
|
|
|
331 |
|
332 |
total_loss += loss.item()
|
333 |
total_examples += mix_complex_spec.size(0)
|
|
|
360 |
|
361 |
with torch.no_grad():
|
362 |
speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
|
363 |
+
if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
|
364 |
+
raise AssertionError("nan or inf in speech_spec_prediction")
|
365 |
if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
|
366 |
raise AssertionError("nan or inf in speech_irm_prediction")
|
367 |
+
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
368 |
+
raise AssertionError("nan or inf in lsnr_prediction")
|
369 |
|
370 |
+
speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
|
371 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
372 |
+
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
373 |
|
374 |
+
loss = speech_loss + irm_loss + snr_loss
|
|
|
375 |
|
376 |
total_loss += loss.item()
|
377 |
total_examples += mix_complex_spec.size(0)
|