Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://github.com/Rikorose/DeepFilterNet/blob/main/DeepFilterNet/df/modules.py#L816 | |
""" | |
from typing import Tuple | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
import torchaudio | |
def local_energy(spec: torch.Tensor, n_frame: int, device: torch.device) -> torch.Tensor:
    """
    Hann-weighted moving average of the per-frame spectral energy.

    :param spec: real-valued spectrum, shape: [b, c, t, f, 2]
    :param n_frame: smoothing window length in frames; an even value is
        raised to the next odd number.
    :param device: device on which the Hann window is allocated.
    :return: smoothed energy, shape: [b, c, t]
    """
    # An odd window padded by half its length on each side yields exactly
    # one output value per input frame.
    window_size = n_frame + 1 if n_frame % 2 == 0 else n_frame
    half = window_size // 2

    # Energy per frame: squared values summed over the last two axes.
    # [b, c, t, f, 2] -> [b, c, t]
    frame_energy = spec.pow(2).sum(-1).sum(-1)

    # Zero-pad the time axis only.
    # [b, c, t] -> [b, c, t + 2 * half]
    padded = F.pad(frame_energy, (half, half, 0, 0))

    # hann shape: [window_size]
    hann = torch.hann_window(window_size, device=device, dtype=padded.dtype)

    # One sliding window per original frame.
    # windows shape: [b, c, t, window_size]
    windows = padded.unfold(-1, size=window_size, step=1)

    # Weighted sum over each window, normalised by the window length.
    # result shape: [b, c, t]
    return (windows * hann).sum(dim=-1).div(window_size)
def local_snr(spec_clean: torch.Tensor,
              spec_noise: torch.Tensor,
              n_frame: int = 5,
              db: bool = False,
              eps: float = 1e-12,
              ):
    """
    Frame-wise signal-to-noise ratio from smoothed local energies.

    :param spec_clean: torch.complex, shape: [b, c, t, f]
    :param spec_noise: torch.complex, shape: [b, c, t, f]
    :param n_frame: smoothing window length (frames) passed to local_energy.
    :param db: if True the ratio is returned in decibels, otherwise linear.
    :param eps: lower bound applied to the noise energy before dividing and
        to the ratio before taking the logarithm.
    :return: tuple (snr, energy_clean, energy_noise), each of shape [b, c, t]
    """
    # complex [b, c, t, f] -> real [b, c, t, f, 2]
    clean_ri = torch.view_as_real(spec_clean)
    noise_ri = torch.view_as_real(spec_noise)

    # energy shape: [b, c, t]
    energy_clean = local_energy(clean_ri, n_frame=n_frame, device=clean_ri.device)
    energy_noise = local_energy(noise_ri, n_frame=n_frame, device=noise_ri.device)

    # Bound the denominator away from zero.
    # ratio shape: [b, c, t]
    ratio = torch.div(energy_clean, torch.clamp(energy_noise, min=eps))
    if not db:
        return ratio, energy_clean, energy_noise

    # Bound the ratio as well so log10 never sees zero.
    snr_db = torch.log10(torch.clamp(ratio, min=eps)).mul(10)
    return snr_db, energy_clean, energy_noise
class LocalSnrTarget(nn.Module):
    """
    Module that turns a clean / noise spectrum pair into a clamped
    local SNR target.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 n_frame: int = 3,
                 min_local_snr: int = -15,
                 max_local_snr: int = 30,
                 db: bool = True,
                 ):
        super().__init__()
        # Settings that drive forward().
        self.n_frame = n_frame
        self.db = db
        self.min_local_snr = min_local_snr
        self.max_local_snr = max_local_snr

        # STFT settings: stored on the module but not read by forward().
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size

    def forward(self,
                spec_clean: torch.Tensor,
                spec_noise: torch.Tensor,
                ) -> torch.Tensor:
        """
        :param spec_clean: torch.complex, shape: [b, c, t, f]
        :param spec_noise: torch.complex, shape: [b, c, t, f]
        :return: lsnr, shape: [b, t]
        """
        # Only the SNR is needed; the two energies are dropped.
        snr_and_energies = local_snr(
            spec_clean=spec_clean,
            spec_noise=spec_noise,
            n_frame=self.n_frame,
            db=self.db,
        )
        lsnr = snr_and_energies[0]
        # lsnr shape: [b, c, t]

        lsnr = torch.clamp(lsnr, self.min_local_snr, self.max_local_snr)

        # squeeze(1) removes the channel axis only when it has size 1.
        # lsnr shape: [b, t]
        return lsnr.squeeze(1)
def main():
    """
    Smoke test: run LocalSnrTarget on the complex STFT of random noise.

    The same spectrum is passed as both the clean and the noise input, so
    this only checks that shapes and dtypes flow through; the printed SNR
    target carries no meaning.
    """
    sample_rate = 8000
    nfft = 512
    win_size = 512
    hop_size = 256
    window_fn = "hamming"

    # power=None makes Spectrogram return the complex spectrum instead of
    # a magnitude / power spectrogram.
    transform = torchaudio.transforms.Spectrogram(
        n_fft=nfft,
        win_length=win_size,
        hop_length=hop_size,
        power=None,
        window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
    )

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    # Call the module instance rather than .forward(): nn.Module.__call__
    # runs registered hooks, which a direct .forward() call skips.
    spec = transform(noisy)
    # [b, f, t] -> [b, t, f]
    spec = spec.permute(0, 2, 1)
    # [b, t, f] -> [b, c, t, f]
    spec = torch.unsqueeze(spec, dim=1)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    local = LocalSnrTarget(
        sample_rate=sample_rate,
        nfft=nfft,
        win_size=win_size,
        hop_size=hop_size,
        n_frame=5,
        min_local_snr=-15,
        max_local_snr=30,
        db=True,
    )

    lsnr_target = local(spec, spec)
    print(f"lsnr_target.shape: {lsnr_target.shape}")
    return
# Run the smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()