Commit: update

examples/clean_unet_aishell/step_2_train_model.py  CHANGED
@@ -171,6 +171,31 @@ def main():
     # optimizer
     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
     optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)
+
+    # resume training
+    last_epoch = -1
+    for epoch_i in serialization_dir.glob("epoch-*"):
+        epoch_i = Path(epoch_i)
+        epoch_idx = epoch_i.stem.split("-")[1]
+        epoch_idx = int(epoch_idx)
+        if epoch_idx > last_epoch:
+            last_epoch = epoch_idx
+
+    if last_epoch != -1:
+        logger.info(f"resume from epoch-{last_epoch}.")
+        model_pt = serialization_dir / f"epoch-{last_epoch}/model.pt"
+        optimizer_pth = serialization_dir / f"epoch-{last_epoch}/optimizer.pth"
+
+        logger.info(f"load state dict for generator.")
+        with open(model_pt.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+
+        logger.info(f"load state dict for optimizer.")
+        with open(optimizer_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        optimizer.load_state_dict(state_dict)
+
     lr_scheduler = LinearWarmupCosineDecay(
         optimizer,
         lr_max=args.learning_rate,
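This hunk adds checkpoint resumption: it scans serialization_dir for epoch-* directories, picks the newest, and reloads both model and optimizer state (torch.load(..., weights_only=True) restricts unpickling to tensors and plain containers, which is the safe way to load a state dict). A minimal standalone sketch of the same scan; the helper name is hypothetical and only the directory-name convention comes from the diff:

from pathlib import Path


def find_last_epoch(serialization_dir: Path) -> int:
    """Return the largest n among epoch-n subdirectories, or -1 if none exist."""
    last_epoch = -1
    for epoch_dir in serialization_dir.glob("epoch-*"):
        # "epoch-12" -> 12; mirrors the stem.split("-")[1] parse in the diff
        last_epoch = max(last_epoch, int(epoch_dir.stem.split("-")[1]))
    return last_epoch

One caveat visible in the diff: lr_scheduler below is constructed fresh regardless, so only the model and optimizer state are restored.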
@@ -180,19 +205,14 @@ def main():
         warmup_proportion=0.05,
         phase=("linear", "cosine"),
     )
+
     # ae_loss_fn = nn.MSELoss(reduction="mean")
     ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
 
     mr_stft_loss_fn = MultiResolutionSTFTLoss(
-        fft_sizes=[512],
-        hop_sizes=[50],
-        win_lengths=[240],
-        # fft_sizes=[256, 512, 1024],
-        # hop_sizes=[25, 50, 120],
-        # win_lengths=[120, 240, 600],
-        # fft_sizes=[512, 1024, 2048],
-        # hop_sizes=[50, 120, 240],
-        # win_lengths=[240, 600, 1200],
+        fft_sizes=[256, 512, 1024],
+        hop_sizes=[25, 50, 120],
+        win_lengths=[120, 240, 600],
         sc_lambda=0.5,
         mag_lambda=0.5,
         band="full"
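The loss hunk replaces the single 512-point STFT resolution (and deletes the commented-out alternatives) with three resolutions, so mismatches are penalized at several time-frequency trade-offs; the band="full" argument presumably applies the loss over all frequency bins rather than a sub-band. MultiResolutionSTFTLoss is a module of this repo; assuming it follows the common ParallelWaveGAN-style formulation, a sketch of what such a loss computes:

import torch
import torch.nn.functional as F


def stft_mag(x: torch.Tensor, fft_size: int, hop: int, win: int) -> torch.Tensor:
    # magnitude spectrogram; the clamp avoids log(0) below
    window = torch.hann_window(win, device=x.device)
    spec = torch.stft(x, fft_size, hop, win, window=window, return_complex=True)
    return spec.abs().clamp(min=1e-7)


def mr_stft_loss(pred: torch.Tensor, target: torch.Tensor,
                 resolutions=((256, 25, 120), (512, 50, 240), (1024, 120, 600)),
                 sc_lambda: float = 0.5, mag_lambda: float = 0.5) -> torch.Tensor:
    # triples are (fft_size, hop_size, win_length), matching the new config
    total = 0.0
    for fft_size, hop, win in resolutions:
        p = stft_mag(pred, fft_size, hop, win)
        t = stft_mag(target, fft_size, hop, win)
        sc = torch.norm(t - p, p="fro") / torch.norm(t, p="fro")  # spectral convergence
        mag = F.l1_loss(torch.log(p), torch.log(t))               # log-magnitude L1
        total = total + sc_lambda * sc + mag_lambda * mag
    return total / len(resolutions)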
@@ -343,6 +363,9 @@ def main():
             model_to_delete: Path = model_list.pop(0)
             shutil.rmtree(model_to_delete.as_posix())
 
+        # save optim
+        torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
+
         # save metric
         if best_metric is None:
             best_idx_epoch = idx_epoch
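The per-epoch save of optimizer.state_dict() is the counterpart of the resume block in the first hunk; without optimizer.pth on disk there would be nothing for optimizer.load_state_dict to restore. A toy round trip under that convention (the linear model, learning rate, and epoch index are stand-ins):

import torch
from pathlib import Path

epoch_dir = Path("serialization_dir/epoch-3")
epoch_dir.mkdir(parents=True, exist_ok=True)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# save side (this commit adds the optimizer line)
torch.save(model.state_dict(), (epoch_dir / "model.pt").as_posix())
torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())

# resume side (what the first hunk does on the next run)
model.load_state_dict(torch.load(epoch_dir / "model.pt", map_location="cpu", weights_only=True))
optimizer.load_state_dict(torch.load(epoch_dir / "optimizer.pth", map_location="cpu", weights_only=True))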