Spaces:
Running
Running
update
Browse files
examples/spectrum_dfnet_aishell/step_3_evaluation.py
CHANGED
@@ -105,10 +105,10 @@ def enhance(mix_spec_complex: torch.Tensor,
|
|
105 |
mask_speech = speech_irm_prediction
|
106 |
mask_noise = 1.0 - speech_irm_prediction
|
107 |
|
108 |
-
print(f"mix_spec_complex: {mix_spec_complex.shape}")
|
109 |
-
print(f"mask_noise: {mask_noise.shape}")
|
110 |
|
111 |
-
|
112 |
noise_spec = mix_spec_complex * mask_noise
|
113 |
|
114 |
speech_wave = istft.forward(speech_spec_prediction)
|
@@ -251,14 +251,14 @@ def main():
|
|
251 |
|
252 |
# mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
|
253 |
# speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
|
254 |
-
batch_size, _, time_steps = speech_irm_prediction.shape
|
255 |
-
speech_irm_prediction = torch.concat(
|
256 |
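-
[
|
257 |
-
speech_irm_prediction,
|
258 |
-
0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
|
259 |
-
],
|
260 |
-
dim=1,
|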
261 |
-
)
|
262 |
# speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
|
263 |
speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
|
264 |
save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
|
|
|
105 |
mask_speech = speech_irm_prediction
|
106 |
mask_noise = 1.0 - speech_irm_prediction
|
107 |
|
108 |
+
# print(f"mix_spec_complex: {mix_spec_complex.shape}")
|
109 |
+
# print(f"mask_noise: {mask_noise.shape}")
|
110 |
|
111 |
+
speech_spec = mix_spec_complex * mask_speech
|
112 |
noise_spec = mix_spec_complex * mask_noise
|
113 |
|
114 |
speech_wave = istft.forward(speech_spec_prediction)
|
|
|
251 |
|
252 |
# mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
|
253 |
# speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
|
254 |
+
# batch_size, _, time_steps = speech_irm_prediction.shape
|
255 |
+
# speech_irm_prediction = torch.concat(
|
256 |
+
# [
|
257 |
+
# speech_irm_prediction,
|
258 |
+
# 0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
|
259 |
+
# ],
|
260 |
+
# dim=1,
|
261 |
+
# )
|
262 |
# speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
|
263 |
speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
|
264 |
save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
|