update
examples/conv_tasnet/run.sh
CHANGED
@@ -3,7 +3,7 @@
 : <<'END'


-sh run.sh --stage
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
 --max_epochs 400
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -298,21 +298,20 @@ def main():
             # evaluation
             total_steps += 1
             if total_steps % args.eval_steps == 0:
-                model.eval()
-                torch.cuda.empty_cache()
-
-                total_pesq_score = 0.
-                total_loss = 0.
-                total_ae_loss = 0.
-                total_neg_si_snr_loss = 0.
-                total_neg_stoi_loss = 0.
-                total_batches = 0.
-
-                progress_bar_train.close()
-                progress_bar_eval = tqdm(
-                    desc="Evaluation; step-{}".format(total_steps),
-                )
                 with torch.no_grad():
+                    torch.cuda.empty_cache()
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_ae_loss = 0.
+                    total_neg_si_snr_loss = 0.
+                    total_neg_stoi_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_train.close()
+                    progress_bar_eval = tqdm(
+                        desc="Evaluation; step-{}".format(total_steps),
+                    )
                     for eval_batch in valid_data_loader:
                         clean_audios, noisy_audios = eval_batch
                         clean_audios = clean_audios.to(device)
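Note on the hunk above: the change moves the evaluation bookkeeping inside torch.no_grad() (the old model.eval() call is dropped; no_grad() only disables gradient tracking and does not switch modules out of training mode) and swaps the training tqdm bar for an evaluation bar, restoring the training bar afterwards. The following is a minimal, self-contained sketch of that pattern; the model, validation data, and eval_steps value are toy placeholders rather than the repo's actual objects.

import torch
from tqdm import tqdm

# toy stand-ins for the repo's model, valid_data_loader and args.eval_steps
model = torch.nn.Linear(8, 8)
valid_data = [torch.randn(4, 8) for _ in range(10)]
eval_steps = 5

progress_bar_train = tqdm(desc="Training")
for total_steps in range(1, 21):
    progress_bar_train.update(1)

    if total_steps % eval_steps == 0:
        with torch.no_grad():
            # free cached GPU memory before evaluating (no-op on CPU)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # close the training bar and open a per-evaluation bar
            progress_bar_train.close()
            progress_bar_eval = tqdm(desc="Evaluation; step-{}".format(total_steps))

            total_loss, total_batches = 0.0, 0
            for batch in valid_data:
                total_loss += model(batch).pow(2).mean().item()
                total_batches += 1
                progress_bar_eval.update(1)
            progress_bar_eval.set_postfix({"loss": round(total_loss / total_batches, 4)})
            progress_bar_eval.close()

            # recreate the training bar with its previous position and labels,
            # mirroring what the new code does after each evaluation
            progress_bar_train = tqdm(
                initial=progress_bar_train.n,
                postfix=progress_bar_train.postfix,
                desc=progress_bar_train.desc,
            )
progress_bar_train.close()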
@@ -357,71 +356,71 @@ def main():
                             "neg_stoi_loss": average_neg_stoi_loss,
                             "mr_stft_loss": average_mr_stft_loss,
                         })
-                [65 removed lines not captured in the source view]
+                    progress_bar_eval.close()
+                    progress_bar_train = tqdm(
+                        initial=progress_bar_train.n,
+                        postfix=progress_bar_train.postfix,
+                        desc=progress_bar_train.desc,
+                    )
+
+                    # save path
+                    save_dir = serialization_dir / "steps-{}".format(total_steps)
+                    save_dir.mkdir(parents=True, exist_ok=False)
+
+                    # save models
+                    model.save_pretrained(save_dir.as_posix())
+
+                    model_list.append(save_dir)
+                    if len(model_list) >= args.num_serialized_models_to_keep:
+                        model_to_delete: Path = model_list.pop(0)
+                        shutil.rmtree(model_to_delete.as_posix())
+
+                    # save optim
+                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
+
+                    # save metric
+                    if best_metric is None:
+                        best_idx_epoch = idx_epoch
+                        best_metric = average_pesq_score
+                    elif average_pesq_score > best_metric:
+                        # great is better.
+                        best_idx_epoch = idx_epoch
+                        best_metric = average_pesq_score
+                    else:
+                        pass
+
+                    metrics = {
+                        "idx_epoch": idx_epoch,
+                        "best_idx_epoch": best_idx_epoch,
+                        "pesq_score": average_pesq_score,
+                        "loss": average_loss,
+                        "ae_loss": average_ae_loss,
+                        "neg_si_snr_loss": average_neg_si_snr_loss,
+                        "neg_stoi_loss": average_neg_stoi_loss,
+                    }
+                    metrics_filename = save_dir / "metrics_epoch.json"
+                    with open(metrics_filename, "w", encoding="utf-8") as f:
+                        json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                    # save best
+                    best_dir = serialization_dir / "best"
+                    if best_idx_epoch == idx_epoch:
+                        if best_dir.exists():
+                            shutil.rmtree(best_dir)
+                        shutil.copytree(save_dir, best_dir)
+
+                    # early stop
+                    early_stop_flag = False
+                    if best_idx_epoch == idx_epoch:
+                        patience_count = 0
+                    else:
+                        patience_count += 1
+                    if patience_count >= args.patience:
+                        early_stop_flag = True
+
+                    # early stop
+                    if early_stop_flag:
+                        break

     return

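Note on the hunk above: each evaluation step now writes a steps-{N} checkpoint directory, keeps only the newest num_serialized_models_to_keep checkpoints, tracks the best PESQ score (higher is better), mirrors the best checkpoint into best/, and raises an early-stop flag after patience evaluations without improvement. The sketch below reproduces that bookkeeping with plain directories and a hard-coded list of fake PESQ scores in place of the repo's model, optimizer, and metrics; it also keys "best" on the step index for brevity, whereas the diff keys it on idx_epoch.

import json
import shutil
from pathlib import Path

# assumed demo values; the real values come from argparse in the training script
serialization_dir = Path("file_dir/serialization_demo")
num_serialized_models_to_keep = 3
patience = 2

if serialization_dir.exists():
    shutil.rmtree(serialization_dir)

model_list = []
best_metric = None
best_step = None
patience_count = 0

for total_steps, pesq_score in enumerate([1.1, 1.4, 1.3, 1.2, 1.25], start=1):
    # save a checkpoint directory for this evaluation step
    save_dir = serialization_dir / "steps-{}".format(total_steps)
    save_dir.mkdir(parents=True, exist_ok=False)

    # keep only the newest N checkpoint directories, deleting the oldest
    model_list.append(save_dir)
    if len(model_list) >= num_serialized_models_to_keep:
        model_to_delete: Path = model_list.pop(0)
        shutil.rmtree(model_to_delete.as_posix())

    # track the best metric and write the metrics file next to the checkpoint
    if best_metric is None or pesq_score > best_metric:
        best_metric = pesq_score
        best_step = total_steps
    metrics = {"total_steps": total_steps, "pesq_score": pesq_score, "best_step": best_step}
    with open(save_dir / "metrics_epoch.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4, ensure_ascii=False)

    # mirror the current checkpoint into "best" when it is the new best,
    # otherwise count toward the early-stop patience
    best_dir = serialization_dir / "best"
    if best_step == total_steps:
        if best_dir.exists():
            shutil.rmtree(best_dir)
        shutil.copytree(save_dir, best_dir)
        patience_count = 0
    else:
        patience_count += 1
        if patience_count >= patience:
            print("early stop at step {}".format(total_steps))
            break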