HoneyTian committed on
Commit 76c7bea · 1 Parent(s): af4c931
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -807,6 +807,7 @@ class SpectrumDfNet(nn.Module):
         # feat_power shape: [batch_size, spec_bins, time_steps]
         # feat_power shape: [batch_size, 1, spec_bins, time_steps]
         # feat_power shape: [batch_size, 1, time_steps, spec_bins]
+        feat_power = feat_power.detach()
 
         # spec shape: [batch_size, spec_bins, time_steps]
         feat_spec = torch.view_as_real(spec_complex)
@@ -815,14 +816,16 @@ class SpectrumDfNet(nn.Module):
         # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
         feat_spec = feat_spec[..., :self.df_decoder.df_bins]
         # feat_spec shape: [batch_size, 2, time_steps, df_bins]
-
-        # spec shape: [batch_size, spec_bins, time_steps]
-        spec = torch.unsqueeze(spec_complex, dim=1)
-        # spec shape: [batch_size, 1, spec_bins, time_steps]
-        spec = spec.permute(0, 1, 3, 2)
-        # spec shape: [batch_size, 1, time_steps, spec_bins]
-        spec = torch.view_as_real(spec)
-        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
+        feat_spec = feat_spec.detach()
+
+        # # spec shape: [batch_size, spec_bins, time_steps]
+        # spec = torch.unsqueeze(spec_complex, dim=1)
+        # # spec shape: [batch_size, 1, spec_bins, time_steps]
+        # spec = spec.permute(0, 1, 3, 2)
+        # # spec shape: [batch_size, 1, time_steps, spec_bins]
+        # spec = torch.view_as_real(spec)
+        # # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
+        # spec = spec.detach()
 
         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
 
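In short, the commit detaches both encoder inputs from the autograd graph, so a loss computed downstream of `self.encoder` no longer backpropagates into whatever produced `feat_power` and `spec_complex`, and the unused extra `spec` copy (the second `view_as_real` pipeline) is commented out instead of being built. The sketch below is a minimal illustration with toy shapes, not code from this repository; the `permute` that reorders the `view_as_real` output into `[batch_size, 2, time_steps, spec_bins]` is an assumption, since those lines fall outside the hunk, and `df_bins` stands in for `self.df_decoder.df_bins`.

import torch

# Illustrative sketch only (toy shapes, not the repository's forward()).
batch_size, spec_bins, time_steps, df_bins = 2, 8, 5, 4

# Stand-in for spec_complex, built from differentiable real/imag parts so the
# effect of .detach() on gradient flow is visible.
real = torch.randn(batch_size, spec_bins, time_steps, requires_grad=True)
imag = torch.randn(batch_size, spec_bins, time_steps, requires_grad=True)
spec_complex = torch.complex(real, imag)

feat_spec = torch.view_as_real(spec_complex)
# feat_spec shape: [batch_size, spec_bins, time_steps, 2]
feat_spec = feat_spec.permute(0, 3, 2, 1)  # assumed reordering to [batch_size, 2, time_steps, spec_bins]
feat_spec = feat_spec[..., :df_bins]
# feat_spec shape: [batch_size, 2, time_steps, df_bins]

print(feat_spec.requires_grad)  # True: gradients would still reach `real` / `imag`

feat_spec = feat_spec.detach()
print(feat_spec.requires_grad)  # False: downstream losses now treat this input as a constant

With the detach in place, the apparent intent is that only the parameters inside the encoder and decoders receive gradients through this path, while the feature-preparation step in front of them behaves as fixed preprocessing.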