HoneyTian committed on
Commit dc01163 · 1 Parent(s): 6c34ab4
examples/conv_tasnet/run.sh CHANGED
@@ -3,7 +3,7 @@
 : <<'END'
 
 
-sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
 --max_epochs 400
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -298,21 +298,20 @@ def main():
             # evaluation
             total_steps += 1
             if total_steps % args.eval_steps == 0:
-                model.eval()
-                torch.cuda.empty_cache()
-
-                total_pesq_score = 0.
-                total_loss = 0.
-                total_ae_loss = 0.
-                total_neg_si_snr_loss = 0.
-                total_neg_stoi_loss = 0.
-                total_batches = 0.
-
-                progress_bar_train.close()
-                progress_bar_eval = tqdm(
-                    desc="Evaluation; step-{}".format(total_steps),
-                )
                 with torch.no_grad():
+                    torch.cuda.empty_cache()
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_ae_loss = 0.
+                    total_neg_si_snr_loss = 0.
+                    total_neg_stoi_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_train.close()
+                    progress_bar_eval = tqdm(
+                        desc="Evaluation; step-{}".format(total_steps),
+                    )
                     for eval_batch in valid_data_loader:
                         clean_audios, noisy_audios = eval_batch
                         clean_audios = clean_audios.to(device)
@@ -357,71 +356,71 @@
                             "neg_stoi_loss": average_neg_stoi_loss,
                             "mr_stft_loss": average_mr_stft_loss,
                         })
-                progress_bar_eval.close()
-                progress_bar_train = tqdm(
-                    initial=progress_bar_train.n,
-                    postfix=progress_bar_train.postfix,
-                    desc=progress_bar_train.desc,
-                )
-
-                # save path
-                save_dir = serialization_dir / "steps-{}".format(total_steps)
-                save_dir.mkdir(parents=True, exist_ok=False)
-
-                # save models
-                model.save_pretrained(save_dir.as_posix())
-
-                model_list.append(save_dir)
-                if len(model_list) >= args.num_serialized_models_to_keep:
-                    model_to_delete: Path = model_list.pop(0)
-                    shutil.rmtree(model_to_delete.as_posix())
-
-                # save optim
-                torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
-
-                # save metric
-                if best_metric is None:
-                    best_idx_epoch = idx_epoch
-                    best_metric = average_pesq_score
-                elif average_pesq_score > best_metric:
-                    # great is better.
-                    best_idx_epoch = idx_epoch
-                    best_metric = average_pesq_score
-                else:
-                    pass
-
-                metrics = {
-                    "idx_epoch": idx_epoch,
-                    "best_idx_epoch": best_idx_epoch,
-                    "pesq_score": average_pesq_score,
-                    "loss": average_loss,
-                    "ae_loss": average_ae_loss,
-                    "neg_si_snr_loss": average_neg_si_snr_loss,
-                    "neg_stoi_loss": average_neg_stoi_loss,
-                }
-                metrics_filename = save_dir / "metrics_epoch.json"
-                with open(metrics_filename, "w", encoding="utf-8") as f:
-                    json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-                # save best
-                best_dir = serialization_dir / "best"
-                if best_idx_epoch == idx_epoch:
-                    if best_dir.exists():
-                        shutil.rmtree(best_dir)
-                    shutil.copytree(save_dir, best_dir)
-
-                # early stop
-                early_stop_flag = False
-                if best_idx_epoch == idx_epoch:
-                    patience_count = 0
-                else:
-                    patience_count += 1
-                    if patience_count >= args.patience:
-                        early_stop_flag = True
-
-                # early stop
-                if early_stop_flag:
-                    break
+                    progress_bar_eval.close()
+                    progress_bar_train = tqdm(
+                        initial=progress_bar_train.n,
+                        postfix=progress_bar_train.postfix,
+                        desc=progress_bar_train.desc,
+                    )
+
+                    # save path
+                    save_dir = serialization_dir / "steps-{}".format(total_steps)
+                    save_dir.mkdir(parents=True, exist_ok=False)
+
+                    # save models
+                    model.save_pretrained(save_dir.as_posix())
+
+                    model_list.append(save_dir)
+                    if len(model_list) >= args.num_serialized_models_to_keep:
+                        model_to_delete: Path = model_list.pop(0)
+                        shutil.rmtree(model_to_delete.as_posix())
+
+                    # save optim
+                    torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
+
+                    # save metric
+                    if best_metric is None:
+                        best_idx_epoch = idx_epoch
+                        best_metric = average_pesq_score
+                    elif average_pesq_score > best_metric:
+                        # great is better.
+                        best_idx_epoch = idx_epoch
+                        best_metric = average_pesq_score
+                    else:
+                        pass
+
+                    metrics = {
+                        "idx_epoch": idx_epoch,
+                        "best_idx_epoch": best_idx_epoch,
+                        "pesq_score": average_pesq_score,
+                        "loss": average_loss,
+                        "ae_loss": average_ae_loss,
+                        "neg_si_snr_loss": average_neg_si_snr_loss,
+                        "neg_stoi_loss": average_neg_stoi_loss,
+                    }
+                    metrics_filename = save_dir / "metrics_epoch.json"
+                    with open(metrics_filename, "w", encoding="utf-8") as f:
+                        json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                    # save best
+                    best_dir = serialization_dir / "best"
+                    if best_idx_epoch == idx_epoch:
+                        if best_dir.exists():
+                            shutil.rmtree(best_dir)
+                        shutil.copytree(save_dir, best_dir)
+
+                    # early stop
+                    early_stop_flag = False
+                    if best_idx_epoch == idx_epoch:
+                        patience_count = 0
+                    else:
+                        patience_count += 1
+                        if patience_count >= args.patience:
+                            early_stop_flag = True
+
+                    # early stop
+                    if early_stop_flag:
+                        break
 
     return
 
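For orientation, here is a minimal, self-contained sketch of the structure step_2_train_model.py ends up with after this commit: the evaluation bookkeeping, the checkpoint rotation, and the early-stop check all run under torch.no_grad(). The toy nn.Linear model, random tensors, temporary serialization directory, and lower-is-better MSE metric are placeholders for the repo's actual Conv-TasNet model, audio data loaders, and PESQ score; this illustrates the pattern, not the project's code.

import shutil
import tempfile
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Placeholder model and validation data (the real script uses Conv-TasNet and audio datasets).
model = nn.Linear(16, 16)
valid_data_loader = DataLoader(
    TensorDataset(torch.randn(8, 16), torch.randn(8, 16)),
    batch_size=4,
)

serialization_dir = Path(tempfile.mkdtemp())
model_list = []
best_metric = None
patience_count = 0
num_serialized_models_to_keep = 5   # stands in for args.num_serialized_models_to_keep
patience = 10                       # stands in for args.patience

for total_steps in range(1, 4):     # stand-in for the real training loop
    # As in the updated script, everything below runs inside torch.no_grad(),
    # so no autograd graph is built for the evaluation pass.
    with torch.no_grad():
        total_loss = 0.0
        total_batches = 0
        progress_bar_eval = tqdm(desc="Evaluation; step-{}".format(total_steps))
        for clean_audios, noisy_audios in valid_data_loader:
            denoise_audios = model(noisy_audios)
            loss = torch.nn.functional.mse_loss(denoise_audios, clean_audios)
            total_loss += loss.item()
            total_batches += 1
            progress_bar_eval.update(1)
        progress_bar_eval.close()
        average_loss = total_loss / total_batches

        # Rotate serialized checkpoints, keeping only the most recent few.
        save_dir = serialization_dir / "steps-{}".format(total_steps)
        save_dir.mkdir(parents=True, exist_ok=False)
        torch.save(model.state_dict(), save_dir / "model.pth")
        model_list.append(save_dir)
        if len(model_list) >= num_serialized_models_to_keep:
            shutil.rmtree(model_list.pop(0).as_posix())

        # Track the best score (lower MSE is better in this toy; the real script
        # tracks PESQ, where higher is better) and early-stop on patience.
        if best_metric is None or average_loss < best_metric:
            best_metric = average_loss
            patience_count = 0
        else:
            patience_count += 1
        if patience_count >= patience:
            break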