HoneyTian commited on
Commit
69fa971
·
1 Parent(s): d9a2a24
examples/dfnet/yaml/config-512.yaml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_name: "dfnet"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+
9
+ spec_bins: 256
10
+
11
+ # model
12
+ conv_channels: 64
13
+ conv_kernel_size_input:
14
+ - 3
15
+ - 3
16
+ conv_kernel_size_inner:
17
+ - 1
18
+ - 3
19
+ conv_lookahead: 0
20
+
21
+ convt_kernel_size_inner:
22
+ - 1
23
+ - 3
24
+
25
+ embedding_hidden_size: 256
26
+ encoder_combine_op: "concat"
27
+
28
+ encoder_emb_skip_op: "none"
29
+ encoder_emb_linear_groups: 16
30
+ encoder_emb_hidden_size: 256
31
+
32
+ encoder_linear_groups: 32
33
+
34
+ decoder_emb_num_layers: 3
35
+ decoder_emb_skip_op: "none"
36
+ decoder_emb_linear_groups: 16
37
+ decoder_emb_hidden_size: 256
38
+
39
+ df_decoder_hidden_size: 256
40
+ df_num_layers: 2
41
+ df_order: 5
42
+ df_bins: 96
43
+ df_gru_skip: "grouped_linear"
44
+ df_decoder_linear_groups: 16
45
+ df_pathway_kernel_size_t: 5
46
+ df_lookahead: 2
47
+
48
+ # lsnr
49
+ n_frame: 3
50
+ lsnr_max: 30
51
+ lsnr_min: -15
52
+ norm_tau: 1.
53
+
54
+ # data
55
+ min_snr_db: -10
56
+ max_snr_db: 20
57
+
58
+ # train
59
+ lr: 0.001
60
+ lr_scheduler: "CosineAnnealingLR"
61
+ lr_scheduler_kwargs:
62
+ T_max: 250000
63
+ eta_min: 0.0001
64
+
65
+ max_epochs: 100
66
+ clip_grad_norm: 10.0
67
+ seed: 1234
68
+
69
+ num_workers: 8
70
+ batch_size: 32
71
+ eval_steps: 10000
72
+
73
+ # runtime
74
+ use_post_filter: true
examples/dfnet/yaml/config.yaml CHANGED
@@ -2,14 +2,14 @@ model_name: "dfnet"
2
 
3
  # spec
4
  sample_rate: 8000
5
- n_fft: 512
6
- win_length: 200
7
  hop_length: 80
8
 
9
- spec_bins: 256
10
 
11
  # model
12
- conv_channels: 64
13
  conv_kernel_size_input:
14
  - 3
15
  - 3
@@ -22,26 +22,26 @@ convt_kernel_size_inner:
22
  - 1
23
  - 3
24
 
25
- embedding_hidden_size: 256
26
  encoder_combine_op: "concat"
27
 
28
  encoder_emb_skip_op: "none"
29
- encoder_emb_linear_groups: 16
30
- encoder_emb_hidden_size: 256
31
 
32
- encoder_linear_groups: 32
33
 
34
  decoder_emb_num_layers: 3
35
  decoder_emb_skip_op: "none"
36
- decoder_emb_linear_groups: 16
37
- decoder_emb_hidden_size: 256
38
 
39
- df_decoder_hidden_size: 256
40
  df_num_layers: 2
41
  df_order: 5
42
- df_bins: 96
43
  df_gru_skip: "grouped_linear"
44
- df_decoder_linear_groups: 16
45
  df_pathway_kernel_size_t: 5
46
  df_lookahead: 2
47
 
 
2
 
3
  # spec
4
  sample_rate: 8000
5
+ n_fft: 160
6
+ win_length: 160
7
  hop_length: 80
8
 
9
+ spec_bins: 80
10
 
11
  # model
12
+ conv_channels: 32
13
  conv_kernel_size_input:
14
  - 3
15
  - 3
 
22
  - 1
23
  - 3
24
 
25
+ embedding_hidden_size: 80
26
  encoder_combine_op: "concat"
27
 
28
  encoder_emb_skip_op: "none"
29
+ encoder_emb_linear_groups: 5
30
+ encoder_emb_hidden_size: 80
31
 
32
+ encoder_linear_groups: 10
33
 
34
  decoder_emb_num_layers: 3
35
  decoder_emb_skip_op: "none"
36
+ decoder_emb_linear_groups: 5
37
+ decoder_emb_hidden_size: 80
38
 
39
+ df_decoder_hidden_size: 80
40
  df_num_layers: 2
41
  df_order: 5
42
+ df_bins: 30
43
  df_gru_skip: "grouped_linear"
44
+ df_decoder_linear_groups: 5
45
  df_pathway_kernel_size_t: 5
46
  df_lookahead: 2
47
 
main.py CHANGED
@@ -61,49 +61,22 @@ def shell(cmd: str):
61
 
62
 
63
  denoise_engines = {
64
- "mpnet-nx-speech-1-epoch": {
65
  "infer_cls": InferenceMPNet,
66
  "kwargs": {
67
- "pretrained_model_path_or_zip_file": (
68
- project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
69
- }
70
- },
71
- "mpnet-nx-speech-20-epoch": {
72
- "infer_cls": InferenceMPNet,
73
- "kwargs": {
74
- "pretrained_model_path_or_zip_file": (
75
- project_path / "trained_models/mpnet-nx-speech-20-epoch.zip").as_posix()
76
- }
77
- },
78
- "mpnet-nx-speech-33-epoch-best": {
79
- "infer_cls": InferenceMPNet,
80
- "kwargs": {
81
- "pretrained_model_path_or_zip_file": (
82
- project_path / "trained_models/mpnet-nx-speech-33-epoch-best.zip").as_posix()
83
- }
84
- },
85
- "mpnet-aishell-1-epoch": {
86
- "infer_cls": InferenceMPNet,
87
- "kwargs": {
88
- "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-1-epoch.zip").as_posix()
89
- }
90
- },
91
- "mpnet-aishell-11-epoch": {
92
- "infer_cls": InferenceMPNet,
93
- "kwargs": {
94
- "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
95
  }
96
  },
97
  "frcrn-dns3": {
98
  "infer_cls": InferenceFRCRN,
99
  "kwargs": {
100
- "pretrained_model_path_or_zip_file": (project_path / "trained_models/frcrn-dns3-220k-steps.zip").as_posix()
101
  }
102
  },
103
  }
104
 
105
 
106
- @lru_cache(maxsize=3)
107
  def load_denoise_model(infer_cls, **kwargs):
108
  infer_engine = infer_cls(**kwargs)
109
 
 
61
 
62
 
63
  denoise_engines = {
64
+ "mpnet-nx-speech": {
65
  "infer_cls": InferenceMPNet,
66
  "kwargs": {
67
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-nx-speech.zip").as_posix()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  }
69
  },
70
  "frcrn-dns3": {
71
  "infer_cls": InferenceFRCRN,
72
  "kwargs": {
73
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/frcrn-dns3.zip").as_posix()
74
  }
75
  },
76
  }
77
 
78
 
79
+ @lru_cache(maxsize=1)
80
  def load_denoise_model(infer_cls, **kwargs):
81
  infer_engine = infer_cls(**kwargs)
82
 
toolbox/torchaudio/models/dfnet/modeling_dfnet.py CHANGED
@@ -215,7 +215,10 @@ class GroupedLinear(nn.Module):
215
 
216
  def forward(self, x: torch.Tensor) -> torch.Tensor:
217
  # x: [..., I]
218
- b, t, _ = x.shape
 
 
 
219
  # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
220
  new_shape = (b, t, self.groups, self.ws)
221
  x = x.view(new_shape)
@@ -633,8 +636,9 @@ class DfDecoder(nn.Module):
633
  GroupedLinear(
634
  input_size=self.df_decoder_hidden_size,
635
  hidden_size=out_dim,
636
- groups=config.df_decoder_linear_groups
637
- ),
 
638
  nn.Tanh()
639
  )
640
  self.df_fc_a = nn.Sequential(
 
215
 
216
  def forward(self, x: torch.Tensor) -> torch.Tensor:
217
  # x: [..., I]
218
+ b, t, f = x.shape
219
+ if f != self.input_size:
220
+ raise AssertionError
221
+
222
  # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
223
  new_shape = (b, t, self.groups, self.ws)
224
  x = x.view(new_shape)
 
636
  GroupedLinear(
637
  input_size=self.df_decoder_hidden_size,
638
  hidden_size=out_dim,
639
+ groups=config.df_decoder_linear_groups,
640
+ # groups = self.df_bins // 5,
641
+ ),
642
  nn.Tanh()
643
  )
644
  self.df_fc_a = nn.Sequential(
toolbox/torchaudio/modules/erb_bands.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+
5
+ import numpy as np
6
+
7
+
8
def freq2erb(freq_hz: float) -> float:
    """
    Map a frequency in Hz onto the ERB (Equivalent Rectangular Bandwidth) scale.

    Glasberg & Moore style approximation; note 1 / (24.7 * 9.265) = 0.00436976.
    Reference: https://www.cnblogs.com/LXP-Never/p/16011229.html
    """
    scaled = freq_hz / (24.7 * 9.265)
    return 9.265 * math.log(scaled + 1)
14
+
15
+
16
def erb2freq(n_erb: float) -> float:
    """Inverse of ``freq2erb``: map an ERB-scale value back to frequency in Hz."""
    bandwidth_const = 24.7 * 9.265
    return bandwidth_const * (math.exp(n_erb / 9.265) - 1)
18
+
19
+
20
def get_erb_widths(sample_rate: int, fft_size: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
    """
    Compute how many STFT frequency bins belong to each ERB band.

    Port of the band-partition logic in
    https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs

    :param sample_rate: audio sample rate in Hz.
    :param fft_size: FFT size; the one-sided spectrum has ``fft_size // 2 + 1`` bins.
    :param erb_bins: number of ERB (Equivalent Rectangular Bandwidth) bands.
    :param min_freq_bins_for_erb: minimum number of frequency bins per ERB band.
    :return: uint64 array of length ``erb_bins`` whose values sum to
             ``fft_size // 2 + 1``.
    """
    nyq_freq = sample_rate / 2.
    freq_width: float = sample_rate / fft_size

    min_erb: float = freq2erb(0.)
    max_erb: float = freq2erb(nyq_freq)

    erb = [0] * erb_bins
    step = (max_erb - min_erb) / erb_bins

    prev_freq_bin = 0
    freq_over = 0
    for i in range(1, erb_bins + 1):
        f = erb2freq(min_erb + i * step)
        freq_bin = int(round(f / freq_width))
        freq_bins = freq_bin - prev_freq_bin - freq_over

        # Borrow bins from the following band when this one would be too narrow;
        # freq_over carries the debt forward so cumulative coverage stays aligned.
        if freq_bins < min_freq_bins_for_erb:
            freq_over = min_freq_bins_for_erb - freq_bins
            freq_bins = min_freq_bins_for_erb
        else:
            freq_over = 0
        erb[i - 1] = freq_bins
        prev_freq_bin = freq_bin

    # Include the Nyquist bin, then clamp so the widths cover exactly
    # fft_size // 2 + 1 bins.  NOTE: integer division (was `fft_size / 2 + 1`)
    # keeps `too_large` and the last width ints, so the uint64 cast below is
    # exact instead of silently truncating a float.
    erb[erb_bins - 1] += 1
    too_large = sum(erb) - (fft_size // 2 + 1)
    if too_large > 0:
        erb[erb_bins - 1] -= too_large
    return np.array(erb, dtype=np.uint64)
58
+
59
+
60
def get_erb_filter_bank(erb_widths: np.ndarray,
                        sample_rate: int,
                        normalized: bool = True,
                        inverse: bool = False,
                        ):
    """
    Build a rectangular ERB filter bank from per-band frequency-bin widths.

    :param erb_widths: per-band bin counts (see ``get_erb_widths``).
    :param sample_rate: kept for interface compatibility; not used here.
    :param normalized: forward: scale each band so its column sums to one;
                       inverse: when False, scale each row to sum to one.
    :param inverse: return the transposed (erb -> freq) bank instead.
    :return: matrix of shape [num_freq_bins, num_erb_bins], or its transpose
             when ``inverse`` is True.
    """
    widths = [int(w) for w in erb_widths]
    total_bins = sum(widths)

    fb: np.ndarray = np.zeros(shape=(total_bins, len(widths)))

    # Each band i is an indicator over its contiguous run of frequency bins.
    start = 0
    for band, width in enumerate(widths):
        fb[start: start + width, band] = 1
        start += width

    if inverse:
        fb = fb.T
        if not normalized:
            fb /= np.sum(fb, axis=1, keepdims=True)
    else:
        if normalized:
            fb /= np.sum(fb, axis=0)
    return fb
82
+
83
+
84
def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
    """
    Compress a (complex) spectrum into ERB-band features, optionally in dB.

    :param spec: spectrum of shape [B, C, T, F].
    :param erb_fb: ERB filter-bank matrix mapping F frequency bins to ERB bands.
    :param db: transform band energies to decibel scale. Defaults to ``True``.
    :return: float32 array of ERB features.
    """
    # Power spectrum: |real + j*imag|^2.
    power = np.abs(spec) ** 2

    # Project the power spectrum onto the ERB bands.
    features = np.matmul(power, erb_fb)

    if db:
        # Small epsilon keeps log10 finite for silent bands.
        features = 10 * np.log10(features + 1e-10)

    return np.array(features, dtype=np.float32)
104
+ return erb_feat
105
+
106
+
107
def main():
    """Smoke test: build an ERB filter bank for an 8 kHz / 512-point-FFT setup."""
    widths = get_erb_widths(
        sample_rate=8000,
        fft_size=512,
        erb_bins=32,
        min_freq_bins_for_erb=2,
    )
    filter_bank = get_erb_filter_bank(
        erb_widths=widths,
        sample_rate=8000,
    )
    print(filter_bank.shape)

    return


if __name__ == "__main__":
    main()