HoneyTian commited on
Commit
efe955e
·
1 Parent(s): 1042eee
toolbox/torchaudio/models/clean_unet/loss.py CHANGED
@@ -4,12 +4,6 @@ import torch
4
  import torch
5
  import torch.nn.functional as F
6
 
7
- # from distutils.version import LooseVersion
8
-
9
-
10
- # is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
11
- is_pytorch_17plus = True
12
-
13
 
14
  def stft(x, fft_size, hop_size, win_length, window):
15
  """
@@ -22,18 +16,9 @@ def stft(x, fft_size, hop_size, win_length, window):
22
  :return: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
23
  """
24
 
25
- if is_pytorch_17plus:
26
- x_stft = torch.stft(
27
- x, fft_size, hop_size, win_length, window, return_complex=False
28
- )
29
- else:
30
- x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
31
- real = x_stft[..., 0]
32
- imag = x_stft[..., 1]
33
-
34
- # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
35
- return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
36
 
 
37
 
38
  class SpectralConvergenceLoss(torch.nn.Module):
39
  """Spectral convergence loss module."""
 
4
  import torch
5
  import torch.nn.functional as F
6
 
 
 
 
 
 
 
7
 
8
  def stft(x, fft_size, hop_size, win_length, window):
9
  """
 
16
  :return: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
17
  """
18
 
19
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
 
 
 
 
 
 
 
 
 
 
20
 
21
+ return x_stft.abs()
22
 
23
  class SpectralConvergenceLoss(torch.nn.Module):
24
  """Spectral convergence loss module."""