HoneyTian committed
Commit 4425d40 · 1 Parent(s): 8128494
examples/conv_tasnet/run.sh CHANGED
@@ -3,7 +3,7 @@
 : <<'END'
 
 
-sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
 --max_epochs 400
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -171,18 +171,18 @@ def main():
     optimizer = torch.optim.AdamW(model.parameters(), config.lr)
 
     # resume training
-    last_epoch = -1
-    for epoch_i in serialization_dir.glob("epoch-*"):
-        epoch_i = Path(epoch_i)
-        epoch_idx = epoch_i.stem.split("-")[1]
-        epoch_idx = int(epoch_idx)
-        if epoch_idx > last_epoch:
-            last_epoch = epoch_idx
-
-    if last_epoch != -1:
-        logger.info(f"resume from epoch-{last_epoch}.")
-        model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
-        optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"
+    last_steps = -1
+    for step_i in serialization_dir.glob("steps-*"):
+        step_i = Path(step_i)
+        step_idx = step_i.stem.split("-")[1]
+        step_idx = int(step_idx)
+        if step_idx > last_steps:
+            last_steps = step_idx
+
+    if last_steps != -1:
+        logger.info(f"resume from steps-{last_steps}.")
+        model_pt = serialization_dir / f"steps-{last_steps}/model.pt"
+        optimizer_pth = serialization_dir / f"steps-{last_steps}/optimizer.pth"
 
     logger.info(f"load state dict for model.")
     with open(model_pt.as_posix(), "rb") as f:
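The new resume loop globs serialization_dir for steps-{N} checkpoint directories, recovers N via Path.stem.split("-")[1] (stem strips only dot-suffixes, so the stem of steps-9000 is the full name), and keeps the highest index. For context, a minimal sketch of the matching save side, assuming the steps-{N}/model.pt and steps-{N}/optimizer.pth layout implied by the resume path; the save_checkpoint helper itself is hypothetical, not from this repo:

from pathlib import Path

import torch


def save_checkpoint(serialization_dir: Path, steps: int, model, optimizer):
    # Hypothetical helper: writes a checkpoint in the steps-{N} layout
    # that the resume loop above discovers with glob("steps-*").
    checkpoint_dir = serialization_dir / f"steps-{steps}"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), (checkpoint_dir / "model.pt").as_posix())
    torch.save(optimizer.state_dict(), (checkpoint_dir / "optimizer.pth").as_posix())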
 
@@ -197,7 +197,7 @@ def main():
     if config.lr_scheduler == "CosineAnnealingLR":
         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer,
-            last_epoch=last_epoch,
+            last_epoch=-1,
             # T_max=10 * config.eval_steps,
             # eta_min=0.01 * config.lr,
             **config.lr_scheduler_kwargs,
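With checkpoints now indexed by training steps rather than epochs, passing the recovered index as last_epoch would fast-forward the cosine schedule by the wrong amount, so the scheduler is re-created with last_epoch=-1 and the schedule restarts on resume. A minimal sketch of the relevant PyTorch behavior (the toy model and T_max value are assumptions for illustration): constructing the scheduler with last_epoch=-1 stamps each param group's current lr as initial_lr and starts from step 0, whereas any other last_epoch raises a KeyError unless initial_lr is already present in the param groups.

import torch

model = torch.nn.Linear(4, 4)  # toy model, for illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# last_epoch=-1: the scheduler records "initial_lr" in each param group
# and starts the cosine curve from step 0.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100,  # assumed here; the repo passes this via config.lr_scheduler_kwargs
    last_epoch=-1,
)

for _ in range(3):
    optimizer.step()
    lr_scheduler.step()

The trade-off is that the learning-rate schedule restarts from config.lr on every resume; carrying the schedule position across restarts would instead mean saving and reloading the scheduler's own state_dict alongside the optimizer's.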