Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py
CHANGED
@@ -818,14 +818,14 @@ class SpectrumDfNet(nn.Module):
|
|
818 |
# feat_spec shape: [batch_size, 2, time_steps, df_bins]
|
819 |
feat_spec = feat_spec.detach()
|
820 |
|
821 |
-
#
|
822 |
-
|
823 |
-
#
|
824 |
-
|
825 |
-
#
|
826 |
-
|
827 |
-
#
|
828 |
-
|
829 |
|
830 |
e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
|
831 |
|
@@ -834,30 +834,30 @@ class SpectrumDfNet(nn.Module):
|
|
834 |
if torch.any(mask > 1) or torch.any(mask < 0):
|
835 |
raise AssertionError
|
836 |
|
837 |
-
|
838 |
-
|
839 |
-
#
|
840 |
-
|
841 |
-
#
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
#
|
846 |
-
|
847 |
-
|
848 |
-
#
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
#
|
855 |
|
856 |
mask = torch.squeeze(mask, dim=1)
|
857 |
mask = mask.permute(0, 2, 1)
|
858 |
# mask shape: [batch_size, spec_bins, time_steps]
|
859 |
|
860 |
-
return
|
861 |
|
862 |
|
863 |
class SpectrumDfNetPretrainedModel(SpectrumDfNet):
|
|
|
818 |
# feat_spec shape: [batch_size, 2, time_steps, df_bins]
|
819 |
feat_spec = feat_spec.detach()
|
820 |
|
821 |
+
# spec shape: [batch_size, spec_bins, time_steps]
|
822 |
+
spec = torch.unsqueeze(spec_complex, dim=1)
|
823 |
+
# spec shape: [batch_size, 1, spec_bins, time_steps]
|
824 |
+
spec = spec.permute(0, 1, 3, 2)
|
825 |
+
# spec shape: [batch_size, 1, time_steps, spec_bins]
|
826 |
+
spec = torch.view_as_real(spec)
|
827 |
+
# spec shape: [batch_size, 1, time_steps, spec_bins, 2]
|
828 |
+
spec = spec.detach()
|
829 |
|
830 |
e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
|
831 |
|
|
|
834 |
if torch.any(mask > 1) or torch.any(mask < 0):
|
835 |
raise AssertionError
|
836 |
|
837 |
+
spec_m = self.mask.forward(spec, mask)
|
838 |
+
|
839 |
+
# lsnr shape: [batch_size, time_steps, 1]
|
840 |
+
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
|
841 |
+
# lsnr shape: [batch_size, 1, time_steps]
|
842 |
+
|
843 |
+
df_coefs = self.df_decoder.forward(emb, c0)
|
844 |
+
df_coefs = self.df_out_transform(df_coefs)
|
845 |
+
# df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
|
846 |
+
|
847 |
+
spec_e = self.df_op.forward(spec.clone(), df_coefs)
|
848 |
+
# spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
|
849 |
+
|
850 |
+
spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
|
851 |
+
|
852 |
+
spec_e = torch.squeeze(spec_e, dim=1)
|
853 |
+
spec_e = spec_e.permute(0, 2, 1, 3)
|
854 |
+
# spec_e shape: [batch_size, spec_bins, time_steps, 2]
|
855 |
|
856 |
mask = torch.squeeze(mask, dim=1)
|
857 |
mask = mask.permute(0, 2, 1)
|
858 |
# mask shape: [batch_size, spec_bins, time_steps]
|
859 |
|
860 |
+
return spec_e, mask, lsnr
|
861 |
|
862 |
|
863 |
class SpectrumDfNetPretrainedModel(SpectrumDfNet):
|