Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -198,7 +198,7 @@ def main():
|
|
198 |
if config.lr_scheduler == "CosineAnnealingLR":
|
199 |
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
200 |
optimizer,
|
201 |
-
last_epoch
|
202 |
# T_max=10 * config.eval_steps,
|
203 |
# eta_min=0.01 * config.lr,
|
204 |
**config.lr_scheduler_kwargs,
|
@@ -250,7 +250,7 @@ def main():
|
|
250 |
total_mr_stft_loss = 0.
|
251 |
total_batches = 0.
|
252 |
|
253 |
-
total_steps = 0
|
254 |
progress_bar_train = tqdm(
|
255 |
desc="Training; epoch-{}".format(idx_epoch),
|
256 |
)
|
|
|
198 |
if config.lr_scheduler == "CosineAnnealingLR":
|
199 |
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
200 |
optimizer,
|
201 |
+
last_epoch=last_epoch,
|
202 |
# T_max=10 * config.eval_steps,
|
203 |
# eta_min=0.01 * config.lr,
|
204 |
**config.lr_scheduler_kwargs,
|
|
|
250 |
total_mr_stft_loss = 0.
|
251 |
total_batches = 0.
|
252 |
|
253 |
+
total_steps = 0 if last_steps == -1 else last_steps
|
254 |
progress_bar_train = tqdm(
|
255 |
desc="Training; epoch-{}".format(idx_epoch),
|
256 |
)
|