Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/clean_unet/loss.py
CHANGED
@@ -4,12 +4,6 @@ import torch
|
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
-
# from distutils.version import LooseVersion
|
8 |
-
|
9 |
-
|
10 |
-
# is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
|
11 |
-
is_pytorch_17plus = True
|
12 |
-
|
13 |
|
14 |
def stft(x, fft_size, hop_size, win_length, window):
|
15 |
"""
|
@@ -22,18 +16,9 @@ def stft(x, fft_size, hop_size, win_length, window):
|
|
22 |
:return: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
23 |
"""
|
24 |
|
25 |
-
|
26 |
-
x_stft = torch.stft(
|
27 |
-
x, fft_size, hop_size, win_length, window, return_complex=False
|
28 |
-
)
|
29 |
-
else:
|
30 |
-
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
31 |
-
real = x_stft[..., 0]
|
32 |
-
imag = x_stft[..., 1]
|
33 |
-
|
34 |
-
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
35 |
-
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
|
36 |
|
|
|
37 |
|
38 |
class SpectralConvergenceLoss(torch.nn.Module):
|
39 |
"""Spectral convergence loss module."""
|
|
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
def stft(x, fft_size, hop_size, win_length, window):
|
9 |
"""
|
|
|
16 |
:return: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
17 |
"""
|
18 |
|
19 |
+
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
return x_stft.abs()
|
22 |
|
23 |
class SpectralConvergenceLoss(torch.nn.Module):
|
24 |
"""Spectral convergence loss module."""
|