Commit: update

- examples/dfnet/run.sh +3 -0
- examples/dfnet/step_1_prepare_data.py +1 -1
- examples/dfnet/step_2_train_model.py +35 -17
- examples/dfnet/yaml/config.yaml +13 -10
- examples/frcrn/step_1_prepare_data.py +1 -1
- toolbox/torchaudio/losses/irm.py +63 -0
- toolbox/torchaudio/losses/snr.py +83 -0
- toolbox/torchaudio/models/dfnet/configuration_dfnet.py +22 -18
- toolbox/torchaudio/models/dfnet/conv_stft.py +31 -29
- toolbox/torchaudio/models/dfnet/modeling_dfnet.py +114 -58
- toolbox/torchaudio/modules/__init__.py +6 -0
- toolbox/torchaudio/modules/local_snr_target.py +149 -0
examples/dfnet/run.sh
CHANGED
@@ -2,6 +2,9 @@
 
 : <<'END'
 
+sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name dfnet-nx-speech \
+--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
+--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
 
 sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet-dns3 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
examples/dfnet/step_1_prepare_data.py
CHANGED
@@ -104,7 +104,7 @@ def main():
     dataset = list()
 
     count = 0
-    process_bar = tqdm(desc="build dataset")
+    process_bar = tqdm(desc="build dataset jsonl")
     with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
             if count >= args.max_count > 0:
examples/dfnet/step_2_train_model.py
CHANGED
@@ -25,6 +25,8 @@ from tqdm import tqdm
 from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
 from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
 from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
+from toolbox.torchaudio.losses.irm import IRMLoss
+from toolbox.torchaudio.losses.snr import LocalSNRLoss
 from toolbox.torchaudio.metrics.pesq import run_pesq_score
 from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
 from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNet, DfNetPretrainedModel
@@ -79,22 +81,20 @@ class CollateFunction(object):
             # noise_wave: torch.Tensor = sample["noise_wave"]
             clean_audio: torch.Tensor = sample["speech_wave"]
             noisy_audio: torch.Tensor = sample["mix_wave"]
-            snr_db: float = sample["snr_db"]
+            # snr_db: float = sample["snr_db"]
 
             clean_audios.append(clean_audio)
             noisy_audios.append(noisy_audio)
-            snr_db_list.append(snr_db)
 
         clean_audios = torch.stack(clean_audios)
         noisy_audios = torch.stack(noisy_audios)
-        snr_db_list = torch.tensor(snr_db_list)
 
         # assert
         if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
             raise AssertionError("nan or inf in clean_audios")
         if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
             raise AssertionError("nan or inf in noisy_audios")
-        return clean_audios, noisy_audios, snr_db_list
+        return clean_audios, noisy_audios
 
 
 collate_fn = CollateFunction()
@@ -146,7 +146,7 @@ def main():
         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
         collate_fn=collate_fn,
         pin_memory=False,
-        prefetch_factor=2,
+        prefetch_factor=None if platform.system() == "Windows" else 2,
     )
     valid_data_loader = DataLoader(
         dataset=valid_dataset,
@@ -157,7 +157,7 @@ def main():
         num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
         collate_fn=collate_fn,
         pin_memory=False,
-        prefetch_factor=2,
+        prefetch_factor=None if platform.system() == "Windows" else 2,
     )
 
     # models
@@ -222,7 +222,6 @@ def main():
         factor_mag=1.0,
         reduction="mean"
     ).to(device)
-    lsnr_loss_fn = nn.L1Loss(reduction="mean")
 
     # training loop
 
@@ -247,8 +246,10 @@ def main():
 
         total_pesq_score = 0.
         total_loss = 0.
+        total_mr_stft_loss = 0.
        total_neg_si_snr_loss = 0.
         total_mask_loss = 0.
+        total_lsnr_loss = 0.
         total_batches = 0.
 
         progress_bar_train = tqdm(
@@ -256,20 +257,18 @@ def main():
             desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
-            clean_audios, noisy_audios, snr_db_list = train_batch
+            clean_audios, noisy_audios = train_batch
             clean_audios: torch.Tensor = clean_audios.to(device)
             noisy_audios: torch.Tensor = noisy_audios.to(device)
-            snr_db_list: torch.Tensor = snr_db_list.to(device)
 
             est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
 
+            mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
             neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
             mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
-
-            # neg_si_snr_loss = lsnr_loss_fn.forward(lsnr, snr_db_list)
+            lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+            loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 1.0 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss.")
                 continue
@@ -286,22 +285,28 @@ def main():
 
             total_pesq_score += pesq_score
             total_loss += loss.item()
+            total_mr_stft_loss += mr_stft_loss.item()
             total_neg_si_snr_loss += neg_si_snr_loss.item()
             total_mask_loss += mask_loss.item()
+            total_lsnr_loss += lsnr_loss.item()
             total_batches += 1
 
             average_pesq_score = round(total_pesq_score / total_batches, 4)
             average_loss = round(total_loss / total_batches, 4)
+            average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
             average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
             average_mask_loss = round(total_mask_loss / total_batches, 4)
+            average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
             progress_bar_train.update(1)
             progress_bar_train.set_postfix({
                 "lr": lr_scheduler.get_last_lr()[0],
                 "pesq_score": average_pesq_score,
                 "loss": average_loss,
+                "mr_stft_loss": average_mr_stft_loss,
                 "neg_si_snr_loss": average_neg_si_snr_loss,
                 "mask_loss": average_mask_loss,
+                "lsnr_loss": average_lsnr_loss,
             })
 
             # evaluation
@@ -312,8 +317,10 @@ def main():
 
                 total_pesq_score = 0.
                 total_loss = 0.
+                total_mr_stft_loss = 0.
                 total_neg_si_snr_loss = 0.
                 total_mask_loss = 0.
+                total_lsnr_loss = 0.
                 total_batches = 0.
 
                 progress_bar_train.close()
@@ -321,17 +328,18 @@ def main():
                     desc="Evaluation; steps-{}k".format(int(step_idx/1000)),
                 )
                 for eval_batch in valid_data_loader:
-                    clean_audios, noisy_audios, snr_db_list = eval_batch
+                    clean_audios, noisy_audios = eval_batch
                     clean_audios: torch.Tensor = clean_audios.to(device)
                     noisy_audios: torch.Tensor = noisy_audios.to(device)
-                    snr_db_list: torch.Tensor = snr_db_list.to(device)
 
                     est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
 
+                    mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
                     neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
                     mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
+                    lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-                    loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
+                    loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 1.0 * lsnr_loss
                     if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                         logger.info(f"find nan or inf in loss.")
                         continue
@@ -342,28 +350,36 @@ def main():
 
                     total_pesq_score += pesq_score
                    total_loss += loss.item()
+                    total_mr_stft_loss += mr_stft_loss.item()
                     total_neg_si_snr_loss += neg_si_snr_loss.item()
                     total_mask_loss += mask_loss.item()
+                    total_lsnr_loss += lsnr_loss.item()
                     total_batches += 1
 
                     average_pesq_score = round(total_pesq_score / total_batches, 4)
                     average_loss = round(total_loss / total_batches, 4)
+                    average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
                     average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
                     average_mask_loss = round(total_mask_loss / total_batches, 4)
+                    average_lsnr_loss = round(total_lsnr_loss / total_batches, 4)
 
                     progress_bar_eval.update(1)
                     progress_bar_eval.set_postfix({
                         "lr": lr_scheduler.get_last_lr()[0],
                         "pesq_score": average_pesq_score,
                         "loss": average_loss,
+                        "mr_stft_loss": average_mr_stft_loss,
                         "neg_si_snr_loss": average_neg_si_snr_loss,
                         "mask_loss": average_mask_loss,
+                        "lsnr_loss": average_lsnr_loss,
                     })
 
                 total_pesq_score = 0.
                 total_loss = 0.
+                total_mr_stft_loss = 0.
                 total_neg_si_snr_loss = 0.
                 total_mask_loss = 0.
+                total_lsnr_loss = 0.
                 total_batches = 0.
 
                 progress_bar_eval.close()
@@ -393,7 +409,7 @@ def main():
                     best_epoch_idx = epoch_idx
                     best_step_idx = step_idx
                     best_metric = average_pesq_score
-                elif average_pesq_score
+                elif average_pesq_score >= best_metric:
                     # great is better.
                     best_epoch_idx = epoch_idx
                     best_step_idx = step_idx
@@ -407,8 +423,10 @@ def main():
                     "best_step_idx": best_step_idx,
                     "pesq_score": average_pesq_score,
                     "loss": average_loss,
+                    "mr_stft_loss": average_mr_stft_loss,
                     "neg_si_snr_loss": average_neg_si_snr_loss,
                     "mask_loss": average_mask_loss,
+                    "lsnr_loss": average_lsnr_loss,
                 }
                 metrics_filename = save_dir / "metrics_epoch.json"
                 with open(metrics_filename, "w", encoding="utf-8") as f:
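For orientation, the training objective after this change sums four equally weighted terms. A minimal sketch with hypothetical scalar stand-ins (the real script computes them with MultiResolutionSTFTLoss, NegativeSISNRLoss, model.mask_loss_fn and model.lsnr_loss_fn, as in the diff above):

    import torch

    # Hypothetical per-batch values standing in for the four loss terms.
    mr_stft_loss = torch.tensor(0.8)       # multi-resolution STFT loss
    neg_si_snr_loss = torch.tensor(-12.5)  # negative SI-SNR (more negative is better)
    mask_loss = torch.tensor(0.1)          # IRM mask L1 loss
    lsnr_loss = torch.tensor(2.3)          # local-SNR MSE loss

    # All four weights are 1.0 in this commit; rebalancing would happen here.
    loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 1.0 * lsnr_loss
    print(loss)  # tensor(-9.3000)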
examples/dfnet/yaml/config.yaml
CHANGED
@@ -31,10 +31,6 @@ encoder_emb_hidden_size: 256
 
 encoder_linear_groups: 32
 
-lsnr_max: 30
-lsnr_min: -15
-norm_tau: 1.
-
 decoder_emb_num_layers: 3
 decoder_emb_skip_op: "none"
 decoder_emb_linear_groups: 16
@@ -49,8 +45,15 @@ df_decoder_linear_groups: 16
 df_pathway_kernel_size_t: 5
 df_lookahead: 2
 
-#
-
+# lsnr
+n_frame: 3
+lsnr_max: 30
+lsnr_min: -15
+norm_tau: 1.
+
+# data
+min_snr_db: -10
+max_snr_db: 20
 
 # train
 lr: 0.001
@@ -63,9 +66,9 @@ max_epochs: 100
 clip_grad_norm: 10.0
 seed: 1234
 
-min_snr_db: -10
-max_snr_db: 20
-
 num_workers: 8
-batch_size:
+batch_size: 4
 eval_steps: 10000
+
+# runtime
+use_post_filter: true
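A quick way to sanity-check the regrouped keys is to load the file directly; a minimal sketch assuming PyYAML is available (the repository's DfNetConfig may consume this file through its own loader):

    import yaml

    with open("examples/dfnet/yaml/config.yaml", "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # The lsnr block above: targets are clamped to [-15, 30] dB and
    # smoothed over n_frame STFT frames.
    print(config["n_frame"], config["lsnr_min"], config["lsnr_max"])
    print(config["batch_size"], config["use_post_filter"])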
examples/frcrn/step_1_prepare_data.py
CHANGED
@@ -104,7 +104,7 @@ def main():
     dataset = list()
 
     count = 0
-    process_bar = tqdm(desc="build dataset")
+    process_bar = tqdm(desc="build dataset jsonl")
    with open(args.train_dataset, "w", encoding="utf-8") as ftrain, open(args.valid_dataset, "w", encoding="utf-8") as fvalid:
         for noise, speech in zip(noise_generator, speech_generator):
             flag = random.random()
toolbox/torchaudio/losses/irm.py
CHANGED
@@ -93,6 +93,69 @@ class CIRMLoss(nn.Module):
         return loss
 
 
+class IRMLoss(nn.Module):
+    """
+    https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/loss.py#L25
+    """
+    def __init__(self,
+                 n_fft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+                 center: bool = True,
+                 eps: float = 1e-8,
+                 reduction: str = "mean",
+                 ):
+        super(IRMLoss, self).__init__()
+        self.n_fft = n_fft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.center = center
+        self.eps = eps
+        self.reduction = reduction
+
+        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
+
+        if reduction not in ("sum", "mean"):
+            raise AssertionError(f"param reduction must be sum or mean.")
+
+    def forward(self, mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        # clean_stft, noisy_stft shape: [b, f, t]
+        stft_clean = torch.stft(
+            clean,
+            n_fft=self.n_fft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+        stft_noise = torch.stft(
+            noise,
+            n_fft=self.n_fft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+
+        mag_clean = torch.abs(stft_clean)
+        mag_noise = torch.abs(stft_noise)
+
+        gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)
+
+        loss = F.l1_loss(gth_irm_mask, mask, reduction=self.reduction)
+        return loss
+
+
 def main():
     batch_size = 2
     signal_length = 16000
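A smoke test for the new class, assuming the import path above. With center=True and hop 256, a 16000-sample signal yields 63 STFT frames, so the predicted mask must be [b, 257, 63] to match the ground-truth ideal ratio mask:

    import torch

    from toolbox.torchaudio.losses.irm import IRMLoss

    loss_fn = IRMLoss(n_fft=512, win_size=512, hop_size=256)

    clean = torch.randn(2, 16000)
    noisy = clean + 0.5 * torch.randn(2, 16000)
    # predicted mask in [0, 1], shape [b, f, t] matching torch.stft output
    mask = torch.rand(2, 257, 63)

    loss = loss_fn.forward(mask, clean, noisy)
    print(loss)  # scalar L1 distance between mask and the ideal ratio mask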
toolbox/torchaudio/losses/snr.py
CHANGED
@@ -5,6 +5,9 @@ https://zhuanlan.zhihu.com/p/627039860
 """
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
+
+from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
 class NegativeSNRLoss(nn.Module):
@@ -83,6 +86,86 @@ class NegativeSISNRLoss(nn.Module):
         return -loss
 
 
+class LocalSNRLoss(nn.Module):
+    """
+    https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
+    """
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 nfft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+                 n_frame: int = 3,
+                 min_local_snr: int = -15,
+                 max_local_snr: int = 30,
+                 db: bool = True,
+                 factor: float = 1,
+                 reduction: str = "mean",
+                 eps: float = 1e-8,
+                 ):
+        super(LocalSNRLoss, self).__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+
+        self.factor = factor
+        self.reduction = reduction
+        self.eps = eps
+
+        self.lsnr_fn = LocalSnrTarget(
+            sample_rate=sample_rate,
+            nfft=nfft,
+            win_size=win_size,
+            hop_size=hop_size,
+            n_frame=n_frame,
+            min_local_snr=min_local_snr,
+            max_local_snr=max_local_snr,
+            db=db,
+        )
+
+        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)
+
+    def forward(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if clean.shape != noisy.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        stft_clean = torch.stft(
+            clean,
+            n_fft=self.nfft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+        stft_noise = torch.stft(
+            noise,
+            n_fft=self.nfft,
+            win_length=self.win_size,
+            hop_length=self.hop_size,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            return_complex=True
+        )
+
+        # lsnr shape: [b, 1, t]
+        lsnr = lsnr.squeeze(1)
+        # lsnr shape: [b, t]
+
+        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
+        # lsnr_gth shape: [b, t]
+
+        loss = F.mse_loss(lsnr, lsnr_gth) * self.factor
+        return loss
+
+
 def main():
     batch_size = 2
     signal_length = 16000
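Two caveats on LocalSNRLoss as committed: forward() reads self.center, which __init__ never assigns, and it passes [b, f, t] spectrograms straight to LocalSnrTarget, whose docstring expects [b, c, t, f]; both need a small patch before the class runs standalone. The underlying target module does run as-is, so a minimal sketch of the intended objective goes through it directly:

    import torch
    import torch.nn.functional as F

    from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget

    lsnr_fn = LocalSnrTarget(sample_rate=8000, nfft=512, win_size=512, hop_size=256,
                             n_frame=3, min_local_snr=-15, max_local_snr=30, db=True)

    # complex spectrograms in the [b, c, t, f] layout the module expects
    spec_clean = torch.randn(2, 1, 63, 257, dtype=torch.complex64)
    spec_noise = torch.randn(2, 1, 63, 257, dtype=torch.complex64)

    lsnr_gth = lsnr_fn.forward(spec_clean, spec_noise)  # [b, t], clamped to [-15, 30]
    lsnr_pred = torch.randn(2, 63)                      # stand-in for the model's lsnr head

    loss = F.mse_loss(lsnr_pred, lsnr_gth)
    print(lsnr_gth.shape, loss)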
toolbox/torchaudio/models/dfnet/configuration_dfnet.py
CHANGED
@@ -31,10 +31,6 @@ class DfNetConfig(PretrainedConfig):
 
                 encoder_linear_groups: int = 32,
 
-                lsnr_max: int = 30,
-                lsnr_min: int = -15,
-                norm_tau: float = 1.,
-
                 decoder_emb_num_layers: int = 3,
                 decoder_emb_skip_op: str = "none",
                 decoder_emb_linear_groups: int = 16,
@@ -49,7 +45,13 @@ class DfNetConfig(PretrainedConfig):
                 df_pathway_kernel_size_t: int = 5,
                 df_lookahead: int = 2,
 
-
+                n_frame: int = 3,
+                max_local_snr: int = 30,
+                min_local_snr: int = -15,
+                norm_tau: float = 1.,
+
+                min_snr_db: float = -10,
+                max_snr_db: float = 20,
 
                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
@@ -59,13 +61,12 @@ class DfNetConfig(PretrainedConfig):
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,
 
-                min_snr_db: float = -10,
-                max_snr_db: float = 20,
-
                 num_workers: int = 4,
                 batch_size: int = 4,
                 eval_steps: int = 25000,
 
+                use_post_filter: bool = False,
+
                 **kwargs
                 ):
         super(DfNetConfig, self).__init__(**kwargs)
@@ -97,10 +98,6 @@ class DfNetConfig(PretrainedConfig):
         self.encoder_linear_groups = encoder_linear_groups
         self.encoder_combine_op = encoder_combine_op
 
-        self.lsnr_max = lsnr_max
-        self.lsnr_min = lsnr_min
-        self.norm_tau = norm_tau
-
         # decoder
         self.decoder_emb_num_layers = decoder_emb_num_layers
         self.decoder_emb_skip_op = decoder_emb_skip_op
@@ -117,10 +114,17 @@ class DfNetConfig(PretrainedConfig):
         self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
         self.df_lookahead = df_lookahead
 
-        #
-        self.
+        # lsnr
+        self.n_frame = n_frame
+        self.max_local_snr = max_local_snr
+        self.min_local_snr = min_local_snr
+        self.norm_tau = norm_tau
+
+        # data snr
+        self.min_snr_db = min_snr_db
+        self.max_snr_db = max_snr_db
 
-        #
+        # train
         self.lr = lr
         self.lr_scheduler = lr_scheduler
         self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
@@ -129,13 +133,13 @@ class DfNetConfig(PretrainedConfig):
         self.clip_grad_norm = clip_grad_norm
         self.seed = seed
 
-        self.min_snr_db = min_snr_db
-        self.max_snr_db = max_snr_db
-
         self.num_workers = num_workers
         self.batch_size = batch_size
         self.eval_steps = eval_steps
 
+        # runtime
+        self.use_post_filter = use_post_filter
+
 
 if __name__ == "__main__":
     pass
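A minimal instantiation check for the regrouped parameters, assuming the remaining constructor arguments keep their defaults:

    from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig

    config = DfNetConfig(
        n_frame=3,
        max_local_snr=30,
        min_local_snr=-15,
        use_post_filter=True,
    )
    print(config.n_frame, config.min_local_snr, config.max_local_snr)
    print(config.use_post_filter)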
toolbox/torchaudio/models/dfnet/conv_stft.py
CHANGED
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from scipy.signal import get_window
+from sympy.physics.units import power
 
 
 def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
@@ -40,7 +41,7 @@ class ConvSTFT(nn.Module):
                 win_size: int,
                 hop_size: int,
                 win_type: str = "hamming",
-                feature_type: str = "real",
+                power: int = None,
                 requires_grad: bool = False):
         super(ConvSTFT, self).__init__()
 
@@ -57,23 +58,29 @@ class ConvSTFT(nn.Module):
 
         self.stride = hop_size
         self.dim = self.nfft
-        self.feature_type = feature_type
+        self.power = power
 
     def forward(self, inputs: torch.Tensor):
         if inputs.dim() == 2:
             inputs = torch.unsqueeze(inputs, 1)
 
-        outputs = F.conv1d(inputs, self.weight, stride=self.stride)
-
-        if self.feature_type == "complex":
-            return outputs
-
-        dim = self.dim // 2 + 1
-        real = outputs[:, :dim, :]
-        imag = outputs[:, dim:, :]
-        mags = torch.sqrt(real**2 + imag**2)
-        phase = torch.atan2(imag, real)
-        return mags, phase
+        matrix = F.conv1d(inputs, self.weight, stride=self.stride)
+        dim = self.dim // 2 + 1
+        real = matrix[:, :dim, :]
+        imag = matrix[:, dim:, :]
+        spec = torch.complex(real, imag)
+
+        if self.power is None:
+            return spec
+        elif self.power == 1:
+            mags = torch.sqrt(real**2 + imag**2)
+            # phase = torch.atan2(imag, real)
+            return mags
+        elif self.power == 2:
+            power = real**2 + imag**2
+            return power
+        else:
+            raise AssertionError
 
 
 class ConviSTFT(nn.Module):
@@ -83,7 +90,6 @@ class ConviSTFT(nn.Module):
                 hop_size: int,
                 nfft: int = None,
                 win_type: str = "hamming",
-                feature_type: str = "real",
                 requires_grad: bool = False):
         super(ConviSTFT, self).__init__()
         if nfft is None:
@@ -100,45 +106,41 @@ class ConviSTFT(nn.Module):
 
         self.stride = hop_size
         self.dim = self.nfft
-        self.feature_type = feature_type
 
         self.register_buffer("window", window)
         self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
 
     def forward(self,
-                inputs: torch.Tensor,
-                phase: torch.Tensor = None):
+                inputs: torch.Tensor):
         """
-        :param inputs: torch.Tensor, shape: [b, n//2+1, t]
-        :param phase: torch.Tensor, shape: [b, n//2+1, t]
+        :param inputs: torch.Tensor, shape: [b, f, t]
         :return:
         """
-        if phase is not None:
-            real = inputs * torch.cos(phase)
-            imag = inputs * torch.sin(phase)
-            inputs = torch.cat([real, imag], 1)
-        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
+        inputs = torch.view_as_real(inputs)
+        matrix = torch.concat(tensors=[inputs[..., 0], inputs[..., 1]], dim=1)
+
+        waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
 
         # this is from torch-stft: https://github.com/pseeth/torch-stft
-        t = self.window.repeat(1, 1, inputs.size(-1))**2
+        t = self.window.repeat(1, 1, matrix.size(-1))**2
         coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
-        outputs = outputs / (coff + 1e-8)
-        return outputs
+        waveform = waveform / (coff + 1e-8)
+        return waveform
 
 
 def main():
-    stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
-    istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")
+    stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
+    istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)
 
     mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
 
     spec = stft.forward(mixture)
     # shape: [batch_size, freq_bins, time_steps]
-    print(spec.shape)
+    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
 
     waveform = istft.forward(spec)
     # shape: [batch_size, channels, num_samples]
-    print(waveform.shape)
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
 
     return
 
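The new power argument follows the torchaudio.transforms.Spectrogram convention: None returns the complex spectrogram, 1 the magnitude, 2 the power spectrum. A quick consistency sketch, assuming the module path above (note the stray sympy import added by this commit is unused and shadowed by the local variable of the same name):

    import torch

    from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT

    wav = torch.rand(size=(1, 16000), dtype=torch.float32)

    spec = ConvSTFT(nfft=512, win_size=512, hop_size=256, power=None).forward(wav)
    mag = ConvSTFT(nfft=512, win_size=512, hop_size=256, power=1).forward(wav)
    power = ConvSTFT(nfft=512, win_size=512, hop_size=256, power=2).forward(wav)

    print(spec.dtype)                                  # torch.complex64
    print(torch.allclose(mag ** 2, power, atol=1e-5))  # magnitude squared == power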
toolbox/torchaudio/models/dfnet/modeling_dfnet.py
CHANGED
@@ -13,6 +13,7 @@ import torchaudio
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
 from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT
+from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
 
 MODEL_FILE = "model.pt"
@@ -415,8 +416,8 @@ class Encoder(nn.Module):
             nn.Linear(self.embedding_output_size, 1),
             nn.Sigmoid()
         )
-        self.lsnr_scale = config.lsnr_max - config.lsnr_min
-        self.lsnr_offset = config.lsnr_min
+        self.lsnr_scale = config.max_local_snr - config.min_local_snr
+        self.lsnr_offset = config.min_local_snr
 
     def forward(self,
                 feat_power: torch.Tensor,
@@ -789,7 +790,7 @@ class DfNet(nn.Module):
     def __init__(self, config: DfNetConfig):
         super(DfNet, self).__init__()
         self.config = config
-        self.eps = 1e-
+        self.eps = 1e-12
 
         self.freq_bins = self.config.nfft // 2 + 1
 
@@ -803,7 +804,7 @@ class DfNet(nn.Module):
             win_size=config.win_size,
             hop_size=config.hop_size,
             win_type=config.win_type,
-            feature_type="complex",
+            power=None,
             requires_grad=False
         )
         self.istft = ConviSTFT(
@@ -811,7 +812,6 @@ class DfNet(nn.Module):
             win_size=config.win_size,
             hop_size=config.hop_size,
             win_type=config.win_type,
-            feature_type="complex",
             requires_grad=False
         )
 
@@ -828,98 +828,121 @@ class DfNet(nn.Module):
 
         self.mask = Mask(use_post_filter=config.use_post_filter)
 
-
-
-
-
-
-
+        self.lsnr_fn = LocalSnrTarget(
+            sample_rate=config.sample_rate,
+            nfft=config.nfft,
+            win_size=config.win_size,
+            hop_size=config.hop_size,
+            n_frame=config.n_frame,
+            min_local_snr=config.min_local_snr,
+            max_local_snr=config.max_local_snr,
+            db=True,
+        )
+
+    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, n_samples = signal.shape
         remainder = (n_samples - self.win_size) % self.hop_size
         if remainder > 0:
             n_samples_pad = self.hop_size - remainder
-
+            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
+        return signal, n_samples
 
-
-
-
-
+    def forward(self,
+                noisy: torch.Tensor,
+                ):
+        """
+        :param noisy:
+        :return:
+            est_spec: shape: [b, 257*2, t]
+            est_wav: shape: [b, num_samples]
+            est_mask: shape: [b, 257, t]
+            lsnr: shape: [b, 1, t]
+        """
+        noisy, n_samples = self.signal_prepare(noisy)
 
-        #
-        cmp_spec =
-
-
-
-
+        # noisy shape: [b, num_samples_pad]
+        cmp_spec = self.stft.forward(noisy)
+        # cmp_spec shape: [b, f, t], torch.complex64
+        cmp_spec = torch.view_as_real(cmp_spec)
+        # cmp_spec shape: [b, f, t, 2]
+        cmp_spec = cmp_spec.permute(0, 3, 1, 2)
+        # cmp_spec shape: [b, 2, f, t]
         cmp_spec = cmp_spec[:, :, :-1, :]
+        # cmp_spec shape: [b, 2, spec_bins, t]
+        # n//2+1 -> n//2; 257 -> 256
 
         spec = torch.unsqueeze(cmp_spec, dim=4)
-        # [
+        # spec shape: [b, 2, spec_bins, t, 1]
         spec = spec.permute(0, 4, 3, 2, 1)
-        # spec shape: [
+        # spec shape: [b, 1, t, spec_bins, 2]
 
         feat_power = torch.sum(torch.square(spec), dim=-1)
-        # feat_power shape: [
+        # feat_power shape: [b, 1, t, spec_bins]
 
         feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
-        # feat_spec shape: [
+        # feat_spec shape: [b, 2, t, spec_bins]
         feat_spec = feat_spec[..., :self.df_decoder.df_bins]
-        # feat_spec shape: [
+        # feat_spec shape: [b, 2, t, df_bins]
 
         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
 
         mask = self.decoder.forward(emb, e3, e2, e1, e0)
-        # mask shape: [
+        # mask shape: [b, 1, t, spec_bins]
         if torch.any(mask > 1) or torch.any(mask < 0):
             raise AssertionError
 
         spec_m = self.mask.forward(spec, mask)
 
-        # lsnr shape: [
+        # lsnr shape: [b, t, 1]
         lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
-        # lsnr shape: [
+        # lsnr shape: [b, 1, t]
 
         df_coefs = self.df_decoder.forward(emb, c0)
         df_coefs = self.df_out_transform(df_coefs)
-        # df_coefs shape: [
+        # df_coefs shape: [b, df_order, t, df_bins, 2]
 
         spec_e = self.df_op.forward(spec.clone(), df_coefs)
-        # est_spec shape: [
+        # est_spec shape: [b, 1, t, spec_bins, 2]
 
         spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
 
         spec_e = torch.squeeze(spec_e, dim=1)
         spec_e = spec_e.permute(0, 2, 1, 3)
-        # spec_e shape: [
+        # spec_e shape: [b, spec_bins, t, 2]
 
         mask = torch.squeeze(mask, dim=1)
         mask = mask.permute(0, 2, 1)
-        # mask shape: [b,
+        # mask shape: [b, spec_bins, t]
         est_mask = self.mask_transfer(mask)
-        # est_mask shape: [b,
+        # est_mask shape: [b, f, t]
 
-        # spec_e shape: [b,
+        # spec_e shape: [b, spec_bins, t, 2]
         est_spec = self.spec_transfer(spec_e)
-        # est_spec shape: [b,
+        # est_spec shape: [b, f, t], torch.complex64
+
         est_wav = self.istft.forward(est_spec)
         est_wav = torch.squeeze(est_wav, dim=1)
         est_wav = est_wav[:, :n_samples]
         # est_wav shape: [b, n_samples]
+
         return est_spec, est_wav, est_mask, lsnr
 
     def spec_transfer(self, spec_e: torch.Tensor) -> torch.Tensor:
-        # spec_e shape: [b,
+        # spec_e shape: [b, spec_bins, t, 2]
         b, _, t, _ = spec_e.shape
-        est_spec = torch.
-        torch.concat(tensors=[
+        est_spec = torch.complex(
+            real=torch.concat(tensors=[
                 spec_e[..., 0],
                 torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
             ], dim=1),
-        torch.concat(tensors=[
+            imag=torch.concat(tensors=[
                 spec_e[..., 1],
                 torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
             ], dim=1),
-
-        # est_spec shape: [b,
+        )
+        # est_spec shape: [b, f, t]
         return est_spec
 
     def mask_transfer(self, mask: torch.Tensor) -> torch.Tensor:
@@ -934,29 +957,58 @@ class DfNet(nn.Module):
 
     def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
         """
-
-        :param est_mask: torch.Tensor, shape: [b, n+2, t]
+        :param est_mask: torch.Tensor, shape: [b, 257, t]
         :param clean:
         :param noisy:
         :return:
         """
-
-
-
-
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        clean, _ = self.signal_prepare(clean)
+        noise, _ = self.signal_prepare(noise)
 
-
-
-        noisy_im = noisy_stft[:, self.freq_bins:, :]
-        noisy_power = noisy_re ** 2 + noisy_im ** 2
+        stft_clean = self.stft.forward(clean)
+        mag_clean = torch.abs(stft_clean)
 
-
-
+        stft_noise = self.stft.forward(noise)
+        mag_noise = torch.abs(stft_noise)
 
-
+        gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)
+
+        loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean")
 
         return loss
 
+    def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
+        if noisy.shape != clean.shape:
+            raise AssertionError("Input signals must have the same shape")
+        noise = noisy - clean
+
+        clean, _ = self.signal_prepare(clean)
+        noise, _ = self.signal_prepare(noise)
+
+        stft_clean = self.stft.forward(clean)
+        stft_noise = self.stft.forward(noise)
+        # shape: [b, f, t]
+        stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
+        stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
+        # shape: [b, t, f]
+        stft_clean = torch.unsqueeze(stft_clean, dim=1)
+        stft_noise = torch.unsqueeze(stft_noise, dim=1)
+        # shape: [b, 1, t, f]
+
+        # lsnr shape: [b, 1, t]
+        lsnr = lsnr.squeeze(1)
+        # lsnr shape: [b, t]
+
+        lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
+        # lsnr_gth shape: [b, t]
+
+        loss = F.mse_loss(lsnr, lsnr_gth)
+        return loss
+
 
 class DfNetPretrainedModel(DfNet):
     def __init__(self,
@@ -1011,8 +1063,12 @@ def main():
 
     noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
 
-
-    print(
+    est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
+    print(f"est_spec.shape: {est_spec.shape}")
+    print(f"est_wav.shape: {est_wav.shape}")
+    print(f"est_mask.shape: {est_mask.shape}")
+    print(f"lsnr.shape: {lsnr.shape}")
+
     return
 
 
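The new signal_prepare pads the waveform so that (n_samples - win_size) is a multiple of hop_size, i.e. the signal tiles into whole STFT frames; a worked check with the default 512/256 framing:

    win_size = 512
    hop_size = 256
    n_samples = 16000

    remainder = (n_samples - win_size) % hop_size  # (16000 - 512) % 256 = 128
    n_samples_pad = hop_size - remainder           # 256 - 128 = 128

    print(((n_samples + n_samples_pad) - win_size) % hop_size)      # 0
    print((n_samples + n_samples_pad - win_size) // hop_size + 1)   # 62 frames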
toolbox/torchaudio/modules/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/modules/local_snr_target.py
ADDED
@@ -0,0 +1,149 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
+"""
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torchaudio
+
+
+def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torch.Tensor:
+    if n_frame % 2 == 0:
+        n_frame += 1
+    n_frame_half = n_frame // 2
+
+    # spec shape: [b, c, t, f, 2]
+    spec = F.pad(spec.pow(2).sum(-1).sum(-1), (n_frame_half, n_frame_half, 0, 0))
+    # spec shape: [b, c, t + 2*n_frame_half]
+
+    weight = torch.hann_window(n_frame, device=device, dtype=spec.dtype)
+    # weight shape: [n_frame]
+
+    spec = spec.unfold(-1, size=n_frame, step=1) * weight
+    # spec shape: [b, c, t, n_frame]
+
+    result = torch.sum(spec, dim=-1).div(n_frame)
+    # result shape: [b, c, t]
+    return result
+
+
+def local_snr(spec_clean: torch.Tensor,
+              spec_noise: torch.Tensor,
+              n_frame: int = 5,
+              db: bool = False,
+              eps: float = 1e-12,
+              ):
+    # [b, c, t, f]
+    spec_clean = torch.view_as_real(spec_clean)
+    spec_noise = torch.view_as_real(spec_noise)
+    # [b, c, t, f, 2]
+
+    energy_clean = local_energy(spec_clean, n_frame=n_frame, device=spec_clean.device)
+    energy_noise = local_energy(spec_noise, n_frame=n_frame, device=spec_noise.device)
+    # [b, c, t]
+
+    snr = energy_clean / energy_noise.clamp_min(eps)
+    # snr shape: [b, c, t]
+
+    if db:
+        snr = snr.clamp_min(eps).log10().mul(10)
+    return snr, energy_clean, energy_noise
+
+
+class LocalSnrTarget(nn.Module):
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 nfft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+
+                 n_frame: int = 3,
+
+                 min_local_snr: int = -15,
+                 max_local_snr: int = 30,
+
+                 db: bool = True,
+                 ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+
+        self.n_frame = n_frame
+
+        self.min_local_snr = min_local_snr
+        self.max_local_snr = max_local_snr
+
+        self.db = db
+
+    def forward(self,
+                spec_clean: torch.Tensor,
+                spec_noise: torch.Tensor,
+                ) -> torch.Tensor:
+        """
+        :param spec_clean: torch.complex, shape: [b, c, t, f]
+        :param spec_noise: torch.complex, shape: [b, c, t, f]
+        :return: lsnr, shape: [b, t]
+        """
+        lsnr, _, _ = local_snr(
+            spec_clean=spec_clean,
+            spec_noise=spec_noise,
+            n_frame=self.n_frame,
+            db=self.db,
+        )
+        # lsnr shape: [b, c, t]
+        lsnr = lsnr.clamp(self.min_local_snr, self.max_local_snr).squeeze(1)
+        # lsnr shape: [b, t]
+        return lsnr
+
+
+def main():
+    sample_rate = 8000
+    nfft = 512
+    win_size = 512
+    hop_size = 256
+    window_fn = "hamming"
+
+    transform = torchaudio.transforms.Spectrogram(
+        n_fft=nfft,
+        win_length=win_size,
+        hop_length=hop_size,
+        power=None,
+        window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
+    )
+
+    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
+
+    spec = transform.forward(noisy)
+    spec = spec.permute(0, 2, 1)
+    spec = torch.unsqueeze(spec, dim=1)
+    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
+
+    # [b, c, t, f]
+    # spec = torch.view_as_real(spec)
+    # [b, c, t, f, 2]
+
+    local = LocalSnrTarget(
+        sample_rate=sample_rate,
+        nfft=nfft,
+        win_size=win_size,
+        hop_size=hop_size,
+        n_frame=5,
+        min_local_snr=-15,
+        max_local_snr=30,
+        db=True,
+    )
+    lsnr_target = local.forward(spec, spec)
+    print(f"lsnr_target.shape: {lsnr_target.shape}")
+    return
+
+
+if __name__ == "__main__":
+    main()
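In formula terms (notation mine, following the DeepFilterNet reference above), the target is a framewise SNR computed from Hann-weighted local energies of the clean and noise spectrograms, with N = n_frame (rounded up to odd), w a Hann window, and X(t) the complex STFT frame at time t:

    E_x(t) = \frac{1}{N} \sum_{k=-\lfloor N/2 \rfloor}^{\lfloor N/2 \rfloor} w(k)\, \lVert X(t+k) \rVert_2^2
    \qquad
    \mathrm{LSNR}(t) = \operatorname{clamp}\!\left( 10 \log_{10} \frac{E_{\mathrm{clean}}(t)}{E_{\mathrm{noise}}(t) + \varepsilon},\ -15,\ 30 \right)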