nx_denoise / toolbox /torchaudio /modules /local_snr_target.py
HoneyTian's picture
update
94ba8b5
raw
history blame
3.97 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816
"""
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio
def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torch.Tensor:
    """Hann-weighted local energy of a real-view spectrogram along the time axis.

    :param spec: real view of a complex spectrogram, expected layout [b, c, t, f, 2]
        (last axis = real/imag pair).
    :param n_frame: length of the local window in frames; an even value is bumped
        to the next odd value so the window can be centred on each frame.
    :param device: device on which the Hann window is created; it must match the
        device of ``spec`` for the multiplication to succeed.
    :return: tensor of shape [b, c, t]: for every frame, the Hann-weighted sum of
        the per-frame energies in the surrounding window, divided by the window
        length.
    """
    # Odd window length -> symmetric padding of n_frame // 2 on each side.
    n_frame = n_frame + 1 if n_frame % 2 == 0 else n_frame
    half = n_frame // 2

    # |X|^2: sum of squares over the (real, imag) pair, then over frequency bins.
    frame_energy = spec.pow(2).sum(dim=-1).sum(dim=-1)
    # Zero-pad only the time axis (last dim) so every frame gets a full window.
    frame_energy = F.pad(frame_energy, (half, half, 0, 0))

    # NOTE(review): torch.hann_window defaults to periodic=True, so for odd n_frame
    # the weights are not symmetric about the centre frame (n_frame=3 -> [0, .75, .75]).
    # Kept unchanged on purpose; confirm this is the intended weighting.
    window = torch.hann_window(n_frame, device=device, dtype=frame_energy.dtype)

    # Sliding windows with stride 1 -> [b, c, t, n_frame], weighted by the Hann window.
    weighted = frame_energy.unfold(-1, size=n_frame, step=1) * window
    return weighted.sum(dim=-1) / n_frame
def _as_real_spec(spec: torch.Tensor, name: str) -> torch.Tensor:
    """Return ``spec`` as a real tensor whose last axis holds the (real, imag) pair.

    :param spec: complex spectrogram, or a spectrogram that is already a real view.
    :param name: argument name used in the error message.
    :raises ValueError: if ``spec`` is real but its last axis does not have size 2.
    """
    if torch.is_complex(spec):
        return torch.view_as_real(spec)
    # A real tensor is only meaningful here if it already is a (real, imag) view;
    # anything else would make local_energy sum over the wrong axes silently.
    if spec.dim() == 0 or spec.shape[-1] != 2:
        raise ValueError(
            f"{name} must be complex or a real view with a trailing axis of size 2, "
            f"got shape {tuple(spec.shape)}"
        )
    return spec


def local_snr(spec_clean: torch.Tensor,
              spec_noise: torch.Tensor,
              n_frame: int = 5,
              db: bool = False,
              eps: float = 1e-12,
              ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the per-frame local SNR between a clean and a noise spectrogram.

    :param spec_clean: clean spectrogram, either complex with layout [b, c, t, f]
        or its real view with layout [b, c, t, f, 2].
    :param spec_noise: noise spectrogram, same layout as ``spec_clean``.
    :param n_frame: local window length in frames (``local_energy`` bumps even
        values to the next odd value).
    :param db: if True, convert the ratio to decibels (10 * log10).
    :param eps: floor applied to the noise energy (and to the ratio before the log)
        to avoid division by zero and log of zero.
    :return: ``(snr, energy_clean, energy_noise)``, each of layout [b, c, t].
    :raises ValueError: if a real-valued input does not end in a size-2 axis.
    """
    # Accept complex input as well as inputs that are already real views.
    spec_clean = _as_real_spec(spec_clean, "spec_clean")
    spec_noise = _as_real_spec(spec_noise, "spec_noise")
    # [b, c, t, f, 2]
    energy_clean = local_energy(spec_clean, n_frame=n_frame, device=spec_clean.device)
    energy_noise = local_energy(spec_noise, n_frame=n_frame, device=spec_noise.device)
    # [b, c, t]
    snr = energy_clean / energy_noise.clamp_min(eps)
    if db:
        snr = snr.clamp_min(eps).log10().mul(10)
    return snr, energy_clean, energy_noise
class LocalSnrTarget(nn.Module):
    """Training target built from the clamped local SNR of clean vs. noise spectrograms.

    Only ``n_frame``, ``db``, ``min_local_snr`` and ``max_local_snr`` influence
    ``forward``; the STFT settings (``sample_rate``, ``nfft``, ``win_size``,
    ``hop_size``) are stored as attributes but not read by this class.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: int = -15,
                 max_local_snr: int = 30,
                 db: bool = True,
                 ):
        super().__init__()
        # STFT configuration (kept for reference; unused by forward()).
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        # Local-SNR configuration consumed by forward().
        self.n_frame = n_frame
        self.min_local_snr = min_local_snr
        self.max_local_snr = max_local_snr
        self.db = db

    def forward(self,
                spec_clean: torch.Tensor,
                spec_noise: torch.Tensor,
                ) -> torch.Tensor:
        """Return the local SNR clamped to [min_local_snr, max_local_snr].

        :param spec_clean: torch.complex, shape: [b, c, t, f]
        :param spec_noise: torch.complex, shape: [b, c, t, f]
        :return: lsnr of shape [b, t] when c == 1; ``squeeze(1)`` leaves the
            channel axis in place for other channel counts. Note the clamp is
            applied even when ``db`` is False (i.e. to a linear ratio).
        """
        snr, _energy_clean, _energy_noise = local_snr(
            spec_clean=spec_clean,
            spec_noise=spec_noise,
            n_frame=self.n_frame,
            db=self.db,
        )
        # [b, c, t] -> clamp -> [b, t] (for a single channel)
        clamped = torch.clamp(snr, min=self.min_local_snr, max=self.max_local_snr)
        return clamped.squeeze(dim=1)
def main():
    """Smoke test: run LocalSnrTarget on a random 16000-sample signal and print shapes."""
    sample_rate, nfft, win_size, hop_size = 8000, 512, 512, 256
    window_fn = "hamming"
    window = torch.hamming_window if window_fn == "hamming" else torch.hann_window

    # power=None keeps the complex STFT (phase included).
    transform = torchaudio.transforms.Spectrogram(
        n_fft=nfft,
        win_length=win_size,
        hop_length=hop_size,
        power=None,
        window_fn=window,
    )
    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    # [b, f, t] -> [b, t, f] -> [b, c=1, t, f]
    spec = transform.forward(noisy).permute(0, 2, 1).unsqueeze(1)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    local = LocalSnrTarget(
        sample_rate=sample_rate,
        nfft=nfft,
        win_size=win_size,
        hop_size=hop_size,
        n_frame=5,
        min_local_snr=-15,
        max_local_snr=30,
        db=True,
    )
    # Clean and noise are the same tensor here, so this only exercises shapes
    # (the SNR is ~0 dB everywhere).
    lsnr_target = local.forward(spec, spec)
    print(f"lsnr_target.shape: {lsnr_target.shape}")


if __name__ == "__main__":
    main()