HoneyTian commited on
Commit
657b015
·
1 Parent(s): 2dbde0d
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -313,22 +313,21 @@ def main():
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
- # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
- # raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
- # if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
- # raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
- # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
- # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
  # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
- # loss = speech_loss + irm_loss + snr_loss
331
- loss = irm_loss
332
 
333
  total_loss += loss.item()
334
  total_examples += mix_complex_spec.size(0)
@@ -361,19 +360,18 @@ def main():
361
 
362
  with torch.no_grad():
363
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
364
- # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
365
- # raise AssertionError("nan or inf in speech_spec_prediction")
366
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
367
  raise AssertionError("nan or inf in speech_irm_prediction")
368
- # if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
369
- # raise AssertionError("nan or inf in lsnr_prediction")
370
 
371
- # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
372
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
373
- # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
- # loss = speech_loss + irm_loss + snr_loss
376
- loss = irm_loss
377
 
378
  total_loss += loss.item()
379
  total_examples += mix_complex_spec.size(0)
 
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
+ if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
+ raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
+ if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
+ raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
+ speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
+ snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
  # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
+ loss = speech_loss + irm_loss + snr_loss
 
331
 
332
  total_loss += loss.item()
333
  total_examples += mix_complex_spec.size(0)
 
360
 
361
  with torch.no_grad():
362
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
363
+ if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
364
+ raise AssertionError("nan or inf in speech_spec_prediction")
365
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
366
  raise AssertionError("nan or inf in speech_irm_prediction")
367
+ if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
368
+ raise AssertionError("nan or inf in lsnr_prediction")
369
 
370
+ speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
371
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
372
+ snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
373
 
374
+ loss = speech_loss + irm_loss + snr_loss
 
375
 
376
  total_loss += loss.item()
377
  total_examples += mix_complex_spec.size(0)