HoneyTian commited on
Commit
92bf47a
·
1 Parent(s): 8c9b2a3
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -198,7 +198,7 @@ def main():
198
  if config.lr_scheduler == "CosineAnnealingLR":
199
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
200
  optimizer,
201
- last_epoch=-1,
202
  # T_max=10 * config.eval_steps,
203
  # eta_min=0.01 * config.lr,
204
  **config.lr_scheduler_kwargs,
@@ -250,7 +250,7 @@ def main():
250
  total_mr_stft_loss = 0.
251
  total_batches = 0.
252
 
253
- total_steps = 0
254
  progress_bar_train = tqdm(
255
  desc="Training; epoch-{}".format(idx_epoch),
256
  )
 
198
  if config.lr_scheduler == "CosineAnnealingLR":
199
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
200
  optimizer,
201
+ last_epoch=last_epoch,
202
  # T_max=10 * config.eval_steps,
203
  # eta_min=0.01 * config.lr,
204
  **config.lr_scheduler_kwargs,
 
250
  total_mr_stft_loss = 0.
251
  total_batches = 0.
252
 
253
+ total_steps = 0 if last_steps == -1 else last_steps
254
  progress_bar_train = tqdm(
255
  desc="Training; epoch-{}".format(idx_epoch),
256
  )