HoneyTian commited on
Commit
909a27e
·
1 Parent(s): 1292672
examples/dfnet/step_2_train_model.py CHANGED
@@ -263,7 +263,7 @@ def main():
263
 
264
  est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
265
 
266
- print(f"est_mask.shape: {est_mask.shape}")
267
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
268
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
269
  # mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
 
263
 
264
  est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
265
 
266
+ print(f"est_mask.shape: {est_mask.shape}, est_mask.dtype: {est_mask.dtype}")
267
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
268
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
269
  # mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
toolbox/torchaudio/models/dfnet/modeling_dfnet.py CHANGED
@@ -892,9 +892,22 @@ class DfNet(nn.Module):
892
  # spec_e shape: [batch_size, spec_bins, time_steps, 2]
893
 
894
  mask = torch.squeeze(mask, dim=1)
895
- est_mask = mask.permute(0, 2, 1)
896
- # mask shape: [batch_size, spec_bins, time_steps]
 
 
 
 
 
 
 
 
 
 
 
897
 
 
 
898
  b, _, t, _ = spec_e.shape
899
  est_spec = torch.cat(tensors=[
900
  torch.concat(tensors=[
@@ -906,12 +919,18 @@ class DfNet(nn.Module):
906
  torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
907
  ], dim=1),
908
  ], dim=1)
909
- # est_spec shape: [b, n+2, t]
910
- est_wav = self.istft.forward(est_spec)
911
- est_wav = torch.squeeze(est_wav, dim=1)
912
- est_wav = est_wav[:, :n_samples]
913
- # est_wav shape: [b, n_samples]
914
- return est_spec, est_wav, est_mask, lsnr
 
 
 
 
 
 
915
 
916
  def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
917
  """
@@ -924,35 +943,18 @@ class DfNet(nn.Module):
924
  clean_stft = self.stft(clean)
925
  clean_re = clean_stft[:, :self.freq_bins, :]
926
  clean_im = clean_stft[:, self.freq_bins:, :]
 
927
 
928
  noisy_stft = self.stft(noisy)
929
  noisy_re = noisy_stft[:, :self.freq_bins, :]
930
  noisy_im = noisy_stft[:, self.freq_bins:, :]
931
-
932
  noisy_power = noisy_re ** 2 + noisy_im ** 2
933
 
934
- sr = clean_re
935
- yr = noisy_re
936
- si = clean_im
937
- yi = noisy_im
938
- y_pow = noisy_power
939
- # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
940
- gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
941
- # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
942
- gth_mask_im = (sr * yr - si * yi) / (y_pow + self.eps)
943
-
944
- gth_mask_re[gth_mask_re > 2] = 1
945
- gth_mask_re[gth_mask_re < -2] = -1
946
- gth_mask_im[gth_mask_im > 2] = 1
947
- gth_mask_im[gth_mask_im < -2] = -1
948
-
949
- mask_re = est_mask[:, :self.freq_bins, :]
950
- mask_im = est_mask[:, self.freq_bins:, :]
951
 
952
- loss_re = F.mse_loss(gth_mask_re, mask_re)
953
- loss_im = F.mse_loss(gth_mask_im, mask_im)
954
 
955
- loss = loss_re + loss_im
956
  return loss
957
 
958
 
 
892
  # spec_e shape: [batch_size, spec_bins, time_steps, 2]
893
 
894
  mask = torch.squeeze(mask, dim=1)
895
+ mask = mask.permute(0, 2, 1)
896
+ # mask shape: [b, 256, t]
897
+ est_mask = self.mask_transfer(mask)
898
+ # est_mask shape: [b, 257, t]
899
+
900
+ # spec_e shape: [b, 256, t, 2]
901
+ est_spec = self.spec_transfer(spec_e)
902
+ # est_spec shape: [b, 257*2, t]
903
+ est_wav = self.istft.forward(est_spec)
904
+ est_wav = torch.squeeze(est_wav, dim=1)
905
+ est_wav = est_wav[:, :n_samples]
906
+ # est_wav shape: [b, n_samples]
907
+ return est_spec, est_wav, est_mask, lsnr
908
 
909
+ def spec_transfer(self, spec_e: torch.Tensor) -> torch.Tensor:
910
+ # spec_e shape: [b, 256, t, 2]
911
  b, _, t, _ = spec_e.shape
912
  est_spec = torch.cat(tensors=[
913
  torch.concat(tensors=[
 
919
  torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
920
  ], dim=1),
921
  ], dim=1)
922
+ # est_spec shape: [b, 257*2, t]
923
+ return est_spec
924
+
925
+ def mask_transfer(self, mask: torch.Tensor) -> torch.Tensor:
926
+ # mask shape: [b, 256, t]
927
+ b, _, t = mask.shape
928
+ est_mask = torch.concat(tensors=[
929
+ mask,
930
+ torch.zeros(size=(b, 1, t), dtype=mask.dtype).to(mask.device)
931
+ ], dim=1)
932
+ # est_mask shape: [b, 257, t]
933
+ return est_mask
934
 
935
  def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
936
  """
 
943
  clean_stft = self.stft(clean)
944
  clean_re = clean_stft[:, :self.freq_bins, :]
945
  clean_im = clean_stft[:, self.freq_bins:, :]
946
+ clean_power = clean_re ** 2 + clean_im ** 2
947
 
948
  noisy_stft = self.stft(noisy)
949
  noisy_re = noisy_stft[:, :self.freq_bins, :]
950
  noisy_im = noisy_stft[:, self.freq_bins:, :]
 
951
  noisy_power = noisy_re ** 2 + noisy_im ** 2
952
 
953
+ speech_irm = clean_power / (noisy_power + self.eps)
954
+ # speech_irm = torch.pow(speech_irm, self.irm_beta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
955
 
956
+ loss = F.mse_loss(est_mask, speech_irm)
 
957
 
 
958
  return loss
959
 
960