Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/run.sh
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
-
sh run.sh --stage
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
--max_epochs 400
|
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
--max_epochs 400
|
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -171,18 +171,18 @@ def main():
|
|
171 |
optimizer = torch.optim.AdamW(model.parameters(), config.lr)
|
172 |
|
173 |
# resume training
|
174 |
-
|
175 |
-
for
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
if
|
180 |
-
|
181 |
-
|
182 |
-
if
|
183 |
-
logger.info(f"resume from
|
184 |
-
model_pt = serialization_dir / f"
|
185 |
-
optimizer_pth = serialization_dir / f"
|
186 |
|
187 |
logger.info(f"load state dict for model.")
|
188 |
with open(model_pt.as_posix(), "rb") as f:
|
@@ -197,7 +197,7 @@ def main():
|
|
197 |
if config.lr_scheduler == "CosineAnnealingLR":
|
198 |
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
199 |
optimizer,
|
200 |
-
last_epoch
|
201 |
# T_max=10 * config.eval_steps,
|
202 |
# eta_min=0.01 * config.lr,
|
203 |
**config.lr_scheduler_kwargs,
|
|
|
171 |
optimizer = torch.optim.AdamW(model.parameters(), config.lr)
|
172 |
|
173 |
# resume training
|
174 |
+
last_steps = -1
|
175 |
+
for step_i in serialization_dir.glob("steps-*"):
|
176 |
+
step_i = Path(step_i)
|
177 |
+
step_idx = step_i.stem.split("-")[1]
|
178 |
+
step_idx = int(step_idx)
|
179 |
+
if step_idx > last_steps:
|
180 |
+
last_steps = step_idx
|
181 |
+
|
182 |
+
if last_steps != -1:
|
183 |
+
logger.info(f"resume from steps-{last_steps}.")
|
184 |
+
model_pt = serialization_dir / f"steps-{last_steps}/model.pt"
|
185 |
+
optimizer_pth = serialization_dir / f"steps-{last_steps}/optimizer.pth"
|
186 |
|
187 |
logger.info(f"load state dict for model.")
|
188 |
with open(model_pt.as_posix(), "rb") as f:
|
|
|
197 |
if config.lr_scheduler == "CosineAnnealingLR":
|
198 |
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
199 |
optimizer,
|
200 |
+
last_epoch=-1,
|
201 |
# T_max=10 * config.eval_steps,
|
202 |
# eta_min=0.01 * config.lr,
|
203 |
**config.lr_scheduler_kwargs,
|