Spaces:
Running
Running
update
Browse files
examples/dfnet/step_2_train_model.py
CHANGED
@@ -263,7 +263,7 @@ def main():
|
|
263 |
|
264 |
est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
|
265 |
|
266 |
-
print(f"est_mask.shape: {est_mask.shape}")
|
267 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
|
268 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
269 |
# mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
263 |
|
264 |
est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
|
265 |
|
266 |
+
print(f"est_mask.shape: {est_mask.shape}, est_mask.dtype: {est_mask.dtype}")
|
267 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
|
268 |
mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
|
269 |
# mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
toolbox/torchaudio/models/dfnet/modeling_dfnet.py
CHANGED
@@ -892,9 +892,22 @@ class DfNet(nn.Module):
|
|
892 |
# spec_e shape: [batch_size, spec_bins, time_steps, 2]
|
893 |
|
894 |
mask = torch.squeeze(mask, dim=1)
|
895 |
-
|
896 |
-
# mask shape: [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
897 |
|
|
|
|
|
898 |
b, _, t, _ = spec_e.shape
|
899 |
est_spec = torch.cat(tensors=[
|
900 |
torch.concat(tensors=[
|
@@ -906,12 +919,18 @@ class DfNet(nn.Module):
|
|
906 |
torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
|
907 |
], dim=1),
|
908 |
], dim=1)
|
909 |
-
# est_spec shape: [b,
|
910 |
-
|
911 |
-
|
912 |
-
|
913 |
-
#
|
914 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
915 |
|
916 |
def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
|
917 |
"""
|
@@ -924,35 +943,18 @@ class DfNet(nn.Module):
|
|
924 |
clean_stft = self.stft(clean)
|
925 |
clean_re = clean_stft[:, :self.freq_bins, :]
|
926 |
clean_im = clean_stft[:, self.freq_bins:, :]
|
|
|
927 |
|
928 |
noisy_stft = self.stft(noisy)
|
929 |
noisy_re = noisy_stft[:, :self.freq_bins, :]
|
930 |
noisy_im = noisy_stft[:, self.freq_bins:, :]
|
931 |
-
|
932 |
noisy_power = noisy_re ** 2 + noisy_im ** 2
|
933 |
|
934 |
-
|
935 |
-
|
936 |
-
si = clean_im
|
937 |
-
yi = noisy_im
|
938 |
-
y_pow = noisy_power
|
939 |
-
# (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
|
940 |
-
gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
|
941 |
-
# (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
|
942 |
-
gth_mask_im = (sr * yr - si * yi) / (y_pow + self.eps)
|
943 |
-
|
944 |
-
gth_mask_re[gth_mask_re > 2] = 1
|
945 |
-
gth_mask_re[gth_mask_re < -2] = -1
|
946 |
-
gth_mask_im[gth_mask_im > 2] = 1
|
947 |
-
gth_mask_im[gth_mask_im < -2] = -1
|
948 |
-
|
949 |
-
mask_re = est_mask[:, :self.freq_bins, :]
|
950 |
-
mask_im = est_mask[:, self.freq_bins:, :]
|
951 |
|
952 |
-
|
953 |
-
loss_im = F.mse_loss(gth_mask_im, mask_im)
|
954 |
|
955 |
-
loss = loss_re + loss_im
|
956 |
return loss
|
957 |
|
958 |
|
|
|
892 |
# spec_e shape: [batch_size, spec_bins, time_steps, 2]
|
893 |
|
894 |
mask = torch.squeeze(mask, dim=1)
|
895 |
+
mask = mask.permute(0, 2, 1)
|
896 |
+
# mask shape: [b, 256, t]
|
897 |
+
est_mask = self.mask_transfer(mask)
|
898 |
+
# est_mask shape: [b, 257, t]
|
899 |
+
|
900 |
+
# spec_e shape: [b, 256, t, 2]
|
901 |
+
est_spec = self.spec_transfer(spec_e)
|
902 |
+
# est_spec shape: [b, 257*2, t]
|
903 |
+
est_wav = self.istft.forward(est_spec)
|
904 |
+
est_wav = torch.squeeze(est_wav, dim=1)
|
905 |
+
est_wav = est_wav[:, :n_samples]
|
906 |
+
# est_wav shape: [b, n_samples]
|
907 |
+
return est_spec, est_wav, est_mask, lsnr
|
908 |
|
909 |
+
def spec_transfer(self, spec_e: torch.Tensor) -> torch.Tensor:
|
910 |
+
# spec_e shape: [b, 256, t, 2]
|
911 |
b, _, t, _ = spec_e.shape
|
912 |
est_spec = torch.cat(tensors=[
|
913 |
torch.concat(tensors=[
|
|
|
919 |
torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
|
920 |
], dim=1),
|
921 |
], dim=1)
|
922 |
+
# est_spec shape: [b, 257*2, t]
|
923 |
+
return est_spec
|
924 |
+
|
925 |
+
def mask_transfer(self, mask: torch.Tensor) -> torch.Tensor:
|
926 |
+
# mask shape: [b, 256, t]
|
927 |
+
b, _, t = mask.shape
|
928 |
+
est_mask = torch.concat(tensors=[
|
929 |
+
mask,
|
930 |
+
torch.zeros(size=(b, 1, t), dtype=mask.dtype).to(mask.device)
|
931 |
+
], dim=1)
|
932 |
+
# est_mask shape: [b, 257, t]
|
933 |
+
return est_mask
|
934 |
|
935 |
def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
|
936 |
"""
|
|
|
943 |
clean_stft = self.stft(clean)
|
944 |
clean_re = clean_stft[:, :self.freq_bins, :]
|
945 |
clean_im = clean_stft[:, self.freq_bins:, :]
|
946 |
+
clean_power = clean_re ** 2 + clean_im ** 2
|
947 |
|
948 |
noisy_stft = self.stft(noisy)
|
949 |
noisy_re = noisy_stft[:, :self.freq_bins, :]
|
950 |
noisy_im = noisy_stft[:, self.freq_bins:, :]
|
|
|
951 |
noisy_power = noisy_re ** 2 + noisy_im ** 2
|
952 |
|
953 |
+
speech_irm = clean_power / (noisy_power + self.eps)
|
954 |
+
# speech_irm = torch.pow(speech_irm, self.irm_beta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
955 |
|
956 |
+
loss = F.mse_loss(est_mask, speech_irm)
|
|
|
957 |
|
|
|
958 |
return loss
|
959 |
|
960 |
|