HoneyTian committed on
Commit
b8db969
·
1 Parent(s): f33a053
examples/spectrum_dfnet_aishell/step_3_evaluation.py CHANGED
@@ -105,10 +105,10 @@ def enhance(mix_spec_complex: torch.Tensor,
105
  mask_speech = speech_irm_prediction
106
  mask_noise = 1.0 - speech_irm_prediction
107
 
108
- print(f"mix_spec_complex: {mix_spec_complex.shape}")
109
- print(f"mask_noise: {mask_noise.shape}")
110
 
111
- # speech_spec = mix_spec_complex * mask_speech
112
  noise_spec = mix_spec_complex * mask_noise
113
 
114
  speech_wave = istft.forward(speech_spec_prediction)
@@ -251,14 +251,14 @@ def main():
251
 
252
  # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
253
  # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
254
- batch_size, _, time_steps = speech_irm_prediction.shape
255
- speech_irm_prediction = torch.concat(
256
- [
257
- speech_irm_prediction,
258
- 0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
259
- ],
260
- dim=1,
261
- )
262
  # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
263
  speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
264
  save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
 
105
  mask_speech = speech_irm_prediction
106
  mask_noise = 1.0 - speech_irm_prediction
107
 
108
+ # print(f"mix_spec_complex: {mix_spec_complex.shape}")
109
+ # print(f"mask_noise: {mask_noise.shape}")
110
 
111
+ speech_spec = mix_spec_complex * mask_speech
112
  noise_spec = mix_spec_complex * mask_noise
113
 
114
  speech_wave = istft.forward(speech_spec_prediction)
 
251
 
252
  # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
253
  # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
254
+ # batch_size, _, time_steps = speech_irm_prediction.shape
255
+ # speech_irm_prediction = torch.concat(
256
+ # [
257
+ # speech_irm_prediction,
258
+ # 0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
259
+ # ],
260
+ # dim=1,
261
+ # )
262
  # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
263
  speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
264
  save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)