HoneyTian commited on
Commit
b8e167f
·
1 Parent(s): a88ebd1
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -171,6 +171,31 @@ def main():
171
  # optimizer
172
  logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
173
  optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  lr_scheduler = LinearWarmupCosineDecay(
175
  optimizer,
176
  lr_max=args.learning_rate,
@@ -180,19 +205,14 @@ def main():
180
  warmup_proportion=0.05,
181
  phase=("linear", "cosine"),
182
  )
 
183
  # ae_loss_fn = nn.MSELoss(reduction="mean")
184
  ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
185
 
186
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
187
- fft_sizes=[512],
188
- hop_sizes=[50],
189
- win_lengths=[240],
190
- # fft_sizes=[256, 512, 1024],
191
- # hop_sizes=[25, 50, 120],
192
- # win_lengths=[120, 240, 600],
193
- # fft_sizes=[512, 1024, 2048],
194
- # hop_sizes=[50, 120, 240],
195
- # win_lengths=[240, 600, 1200],
196
  sc_lambda=0.5,
197
  mag_lambda=0.5,
198
  band="full"
@@ -343,6 +363,9 @@ def main():
343
  model_to_delete: Path = model_list.pop(0)
344
  shutil.rmtree(model_to_delete.as_posix())
345
 
 
 
 
346
  # save metric
347
  if best_metric is None:
348
  best_idx_epoch = idx_epoch
 
171
  # optimizer
172
  logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
173
  optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)
174
+
175
+ # resume training
176
+ last_epoch = -1
177
+ for epoch_i in serialization_dir.glob("epoch-*"):
178
+ epoch_i = Path(epoch_i)
179
+ epoch_idx = epoch_i.stem.split("-")[1]
180
+ epoch_idx = int(epoch_idx)
181
+ if epoch_idx > last_epoch:
182
+ last_epoch = epoch_idx
183
+
184
+ if last_epoch != -1:
185
+ logger.info(f"resume from epoch-{last_epoch}.")
186
+ model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
187
+ optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"
188
+
189
+ logger.info(f"load state dict for generator.")
190
+ with open(model_pt.as_posix(), "rb") as f:
191
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
192
+ model.load_state_dict(state_dict, strict=True)
193
+
194
+ logger.info(f"load state dict for optimizer.")
195
+ with open(optimizer_pth.as_posix(), "rb") as f:
196
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
197
+ optimizer.load_state_dict(state_dict)
198
+
199
  lr_scheduler = LinearWarmupCosineDecay(
200
  optimizer,
201
  lr_max=args.learning_rate,
 
205
  warmup_proportion=0.05,
206
  phase=("linear", "cosine"),
207
  )
208
+
209
  # ae_loss_fn = nn.MSELoss(reduction="mean")
210
  ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
211
 
212
  mr_stft_loss_fn = MultiResolutionSTFTLoss(
213
+ fft_sizes=[256, 512, 1024],
214
+ hop_sizes=[25, 50, 120],
215
+ win_lengths=[120, 240, 600],
 
 
 
 
 
 
216
  sc_lambda=0.5,
217
  mag_lambda=0.5,
218
  band="full"
 
363
  model_to_delete: Path = model_list.pop(0)
364
  shutil.rmtree(model_to_delete.as_posix())
365
 
366
+ # save optim
367
+ torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
368
+
369
  # save metric
370
  if best_metric is None:
371
  best_idx_epoch = idx_epoch