HoneyTian committed on
Commit
85947fe
·
1 Parent(s): 987be40
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -319,8 +319,12 @@ def main():
319
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
320
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
321
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
 
 
 
 
322
 
323
- total_pesq_metric += pesq_metric.item()
324
  total_loss += loss.item()
325
  total_ae_loss += ae_loss.item()
326
  total_sc_loss += sc_loss.item()
 
319
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
320
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
321
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
322
+ if pesq_metric is None:
323
+ pesq_metric = 0
324
+ else:
325
+ pesq_metric = torch.mean(pesq_metric).item()
326
 
327
+ total_pesq_metric += pesq_metric
328
  total_loss += loss.item()
329
  total_ae_loss += ae_loss.item()
330
  total_sc_loss += sc_loss.item()
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -144,7 +144,7 @@ class CleanUNet(nn.Module):
144
  nn.Conv1d(channels_h, channels_h * 2, 1),
145
  nn.GLU(dim=1),
146
  nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
147
- # nn.ReLU(inplace=False)
148
  ))
149
  channels_output = channels_h
150
 
 
144
  nn.Conv1d(channels_h, channels_h * 2, 1),
145
  nn.GLU(dim=1),
146
  nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
147
+ nn.ReLU(inplace=False)
148
  ))
149
  channels_output = channels_h
150