Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -292,7 +292,7 @@ def main():
|
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
293 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
294 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
295 |
-
loss = irm_loss + 0.
|
296 |
# loss = irm_loss
|
297 |
|
298 |
total_loss += loss.item()
|
@@ -327,7 +327,7 @@ def main():
|
|
327 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
328 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
329 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
330 |
-
loss = irm_loss + 0.
|
331 |
# loss = irm_loss
|
332 |
|
333 |
total_loss += loss.item()
|
|
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
293 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
294 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
295 |
+
loss = irm_loss + 0.01 * snr_loss
|
296 |
# loss = irm_loss
|
297 |
|
298 |
total_loss += loss.item()
|
|
|
327 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
328 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
329 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
330 |
+
loss = irm_loss + 0.01 * snr_loss
|
331 |
# loss = irm_loss
|
332 |
|
333 |
total_loss += loss.item()
|