HoneyTian commited on
Commit
3332930
·
1 Parent(s): 7c192b8
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -214,6 +214,7 @@ def main():
214
  elif config.lr_scheduler == "MultiStepLR":
215
  lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
216
  optimizer,
 
217
  milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
218
  )
219
  else:
@@ -345,7 +346,7 @@ def main():
345
 
346
  progress_bar_train.close()
347
  progress_bar_eval = tqdm(
348
- desc="Evaluation; step-{}k".format(int(step_idx/1000)),
349
  )
350
  for eval_batch in valid_data_loader:
351
  clean_audios, noisy_audios = eval_batch
@@ -418,7 +419,7 @@ def main():
418
  )
419
 
420
  # save path
421
- save_dir = serialization_dir / "steps-{}k".format(int(step_idx/1000))
422
  save_dir.mkdir(parents=True, exist_ok=False)
423
 
424
  # save models
@@ -455,6 +456,7 @@ def main():
455
  "neg_si_snr_loss": average_neg_si_snr_loss,
456
  "neg_stoi_loss": average_neg_stoi_loss,
457
  "mr_stft_loss": average_mr_stft_loss,
 
458
  }
459
  metrics_filename = save_dir / "metrics_epoch.json"
460
  with open(metrics_filename, "w", encoding="utf-8") as f:
 
214
  elif config.lr_scheduler == "MultiStepLR":
215
  lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
216
  optimizer,
217
+ last_epoch=last_epoch,
218
  milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
219
  )
220
  else:
 
346
 
347
  progress_bar_train.close()
348
  progress_bar_eval = tqdm(
349
+ desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
350
  )
351
  for eval_batch in valid_data_loader:
352
  clean_audios, noisy_audios = eval_batch
 
419
  )
420
 
421
  # save path
422
+ save_dir = serialization_dir / "steps-{}".format(step_idx)
423
  save_dir.mkdir(parents=True, exist_ok=False)
424
 
425
  # save models
 
456
  "neg_si_snr_loss": average_neg_si_snr_loss,
457
  "neg_stoi_loss": average_neg_stoi_loss,
458
  "mr_stft_loss": average_mr_stft_loss,
459
+ "pesq_loss": average_pesq_loss,
460
  }
461
  metrics_filename = save_dir / "metrics_epoch.json"
462
  with open(metrics_filename, "w", encoding="utf-8") as f: