HoneyTian commited on
Commit
10059e6
·
1 Parent(s): 797d498
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -267,12 +267,10 @@ def main():
267
  initial=step_idx,
268
  desc="Training; epoch-{}".format(epoch_idx),
269
  )
270
- for idx, train_batch in enumerate(train_data_loader):
271
- if idx < step_idx:
272
- continue
273
  clean_audios, noisy_audios = train_batch
274
- clean_audios = clean_audios.to(device)
275
- noisy_audios = noisy_audios.to(device)
276
 
277
  denoise_audios = model.forward(noisy_audios)
278
  denoise_audios = torch.squeeze(denoise_audios, dim=1)
 
267
  initial=step_idx,
268
  desc="Training; epoch-{}".format(epoch_idx),
269
  )
270
+ for train_batch in train_data_loader:
 
 
271
  clean_audios, noisy_audios = train_batch
272
+ clean_audios: torch.Tensor = clean_audios.to(device)
273
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
274
 
275
  denoise_audios = model.forward(noisy_audios)
276
  denoise_audios = torch.squeeze(denoise_audios, dim=1)