HoneyTian commited on
Commit
7e91720
·
1 Parent(s): 8cf37ea
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -292,7 +292,7 @@ def main():
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
294
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
295
- loss = irm_loss + 0.1 * snr_loss
296
  # loss = irm_loss
297
 
298
  total_loss += loss.item()
@@ -327,7 +327,7 @@ def main():
327
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
328
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
329
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
330
- loss = irm_loss + 0.1 * snr_loss
331
  # loss = irm_loss
332
 
333
  total_loss += loss.item()
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
294
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
295
+ loss = irm_loss + 0.01 * snr_loss
296
  # loss = irm_loss
297
 
298
  total_loss += loss.item()
 
327
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
328
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
329
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
330
+ loss = irm_loss + 0.01 * snr_loss
331
  # loss = irm_loss
332
 
333
  total_loss += loss.item()