HoneyTian committed
Commit 2dbde0d · 1 Parent(s): 76c7bea
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -818,14 +818,14 @@ class SpectrumDfNet(nn.Module):
         # feat_spec shape: [batch_size, 2, time_steps, df_bins]
         feat_spec = feat_spec.detach()
 
-        # # spec shape: [batch_size, spec_bins, time_steps]
-        # spec = torch.unsqueeze(spec_complex, dim=1)
-        # # spec shape: [batch_size, 1, spec_bins, time_steps]
-        # spec = spec.permute(0, 1, 3, 2)
-        # # spec shape: [batch_size, 1, time_steps, spec_bins]
-        # spec = torch.view_as_real(spec)
-        # # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
-        # spec = spec.detach()
+        # spec shape: [batch_size, spec_bins, time_steps]
+        spec = torch.unsqueeze(spec_complex, dim=1)
+        # spec shape: [batch_size, 1, spec_bins, time_steps]
+        spec = spec.permute(0, 1, 3, 2)
+        # spec shape: [batch_size, 1, time_steps, spec_bins]
+        spec = torch.view_as_real(spec)
+        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
+        spec = spec.detach()
 
         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
 
@@ -834,30 +834,30 @@ class SpectrumDfNet(nn.Module):
         if torch.any(mask > 1) or torch.any(mask < 0):
             raise AssertionError
 
-        # spec_m = self.mask.forward(spec, mask)
-        #
-        # # lsnr shape: [batch_size, time_steps, 1]
-        # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
-        # # lsnr shape: [batch_size, 1, time_steps]
-        #
-        # df_coefs = self.df_decoder.forward(emb, c0)
-        # df_coefs = self.df_out_transform(df_coefs)
-        # # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
-        #
-        # spec_e = self.df_op.forward(spec.clone(), df_coefs)
-        # # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
-        #
-        # spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
-        #
-        # spec_e = torch.squeeze(spec_e, dim=1)
-        # spec_e = spec_e.permute(0, 2, 1, 3)
-        # # spec_e shape: [batch_size, spec_bins, time_steps, 2]
+        spec_m = self.mask.forward(spec, mask)
+
+        # lsnr shape: [batch_size, time_steps, 1]
+        lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
+        # lsnr shape: [batch_size, 1, time_steps]
+
+        df_coefs = self.df_decoder.forward(emb, c0)
+        df_coefs = self.df_out_transform(df_coefs)
+        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
+
+        spec_e = self.df_op.forward(spec.clone(), df_coefs)
+        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
+
+        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
+
+        spec_e = torch.squeeze(spec_e, dim=1)
+        spec_e = spec_e.permute(0, 2, 1, 3)
+        # spec_e shape: [batch_size, spec_bins, time_steps, 2]
 
         mask = torch.squeeze(mask, dim=1)
         mask = mask.permute(0, 2, 1)
         # mask shape: [batch_size, spec_bins, time_steps]
 
-        return None, mask, lsnr
+        return spec_e, mask, lsnr
 
 
 class SpectrumDfNetPretrainedModel(SpectrumDfNet):
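
For reference, a minimal standalone sketch of what the newly uncommented code does: the complex spectrogram is reshaped into a real-valued view before deep filtering, and the mask branch output fills the bins above df_bins before the new return spec_e, mask, lsnr. The dimensions (257 spec_bins, 100 time_steps, 96 df_bins) are assumed toy values, and spec_m / spec_e are stand-in tensors; only the reshaping and recombination logic from the diff is reproduced here, not the model's own mask, df_decoder, or df_op modules.

import torch

batch_size, spec_bins, time_steps, df_bins = 2, 257, 100, 96

# spec_complex shape: [batch_size, spec_bins, time_steps]
spec_complex = torch.randn(batch_size, spec_bins, time_steps, dtype=torch.complex64)

spec = torch.unsqueeze(spec_complex, dim=1)   # [batch_size, 1, spec_bins, time_steps]
spec = spec.permute(0, 1, 3, 2)               # [batch_size, 1, time_steps, spec_bins]
spec = torch.view_as_real(spec)               # [batch_size, 1, time_steps, spec_bins, 2]

# Stand-ins (assumed) for the mask-branch output spec_m and the
# deep-filtering output spec_e; only the recombination is shown.
spec_m = torch.randn_like(spec)
spec_e = torch.randn_like(spec)

# Bins above df_bins are taken from the mask branch; lower bins stay deep-filtered.
spec_e[..., df_bins:, :] = spec_m[..., df_bins:, :]

spec_e = torch.squeeze(spec_e, dim=1)         # [batch_size, time_steps, spec_bins, 2]
spec_e = spec_e.permute(0, 2, 1, 3)           # [batch_size, spec_bins, time_steps, 2]
print(spec_e.shape)                           # torch.Size([2, 257, 100, 2])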