Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -214,6 +214,7 @@ def main():
|
|
214 |
elif config.lr_scheduler == "MultiStepLR":
|
215 |
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
216 |
optimizer,
|
|
|
217 |
milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
|
218 |
)
|
219 |
else:
|
@@ -345,7 +346,7 @@ def main():
|
|
345 |
|
346 |
progress_bar_train.close()
|
347 |
progress_bar_eval = tqdm(
|
348 |
-
desc="Evaluation;
|
349 |
)
|
350 |
for eval_batch in valid_data_loader:
|
351 |
clean_audios, noisy_audios = eval_batch
|
@@ -418,7 +419,7 @@ def main():
|
|
418 |
)
|
419 |
|
420 |
# save path
|
421 |
-
save_dir = serialization_dir / "steps-{}
|
422 |
save_dir.mkdir(parents=True, exist_ok=False)
|
423 |
|
424 |
# save models
|
@@ -455,6 +456,7 @@ def main():
|
|
455 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
456 |
"neg_stoi_loss": average_neg_stoi_loss,
|
457 |
"mr_stft_loss": average_mr_stft_loss,
|
|
|
458 |
}
|
459 |
metrics_filename = save_dir / "metrics_epoch.json"
|
460 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|
|
|
214 |
elif config.lr_scheduler == "MultiStepLR":
|
215 |
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
216 |
optimizer,
|
217 |
+
last_epoch=last_epoch,
|
218 |
milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
|
219 |
)
|
220 |
else:
|
|
|
346 |
|
347 |
progress_bar_train.close()
|
348 |
progress_bar_eval = tqdm(
|
349 |
+
desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
|
350 |
)
|
351 |
for eval_batch in valid_data_loader:
|
352 |
clean_audios, noisy_audios = eval_batch
|
|
|
419 |
)
|
420 |
|
421 |
# save path
|
422 |
+
save_dir = serialization_dir / "steps-{}".format(step_idx)
|
423 |
save_dir.mkdir(parents=True, exist_ok=False)
|
424 |
|
425 |
# save models
|
|
|
456 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
457 |
"neg_stoi_loss": average_neg_stoi_loss,
|
458 |
"mr_stft_loss": average_mr_stft_loss,
|
459 |
+
"pesq_loss": average_pesq_loss,
|
460 |
}
|
461 |
metrics_filename = save_dir / "metrics_epoch.json"
|
462 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|