Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -319,8 +319,12 @@ def main():
|
|
319 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
320 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
321 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
|
|
|
|
|
|
|
|
322 |
|
323 |
-
total_pesq_metric += pesq_metric
|
324 |
total_loss += loss.item()
|
325 |
total_ae_loss += ae_loss.item()
|
326 |
total_sc_loss += sc_loss.item()
|
|
|
319 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
320 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
321 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
322 |
+
if pesq_metric is None:
|
323 |
+
pesq_metric = 0
|
324 |
+
else:
|
325 |
+
pesq_metric = torch.mean(pesq_metric).item()
|
326 |
|
327 |
+
total_pesq_metric += pesq_metric
|
328 |
total_loss += loss.item()
|
329 |
total_ae_loss += ae_loss.item()
|
330 |
total_sc_loss += sc_loss.item()
|
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py
CHANGED
@@ -144,7 +144,7 @@ class CleanUNet(nn.Module):
|
|
144 |
nn.Conv1d(channels_h, channels_h * 2, 1),
|
145 |
nn.GLU(dim=1),
|
146 |
nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
|
147 |
-
|
148 |
))
|
149 |
channels_output = channels_h
|
150 |
|
|
|
144 |
nn.Conv1d(channels_h, channels_h * 2, 1),
|
145 |
nn.GLU(dim=1),
|
146 |
nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
|
147 |
+
nn.ReLU(inplace=False)
|
148 |
))
|
149 |
channels_output = channels_h
|
150 |
|