HoneyTian committed on
Commit 2c570c3 · 1 Parent(s): f86fc1a
examples/nx_clean_unet/yaml/config.yaml CHANGED
@@ -12,6 +12,14 @@ down_sampling_hidden_channels: 64
  down_sampling_kernel_size: 4
  down_sampling_stride: 2

+ causal_in_channels: 64
+ causal_out_channels: 64
+ causal_kernel_size: 3
+ causal_bias: false
+ causal_separable: true
+ causal_f_stride: 1
+ causal_num_layers: 3
+
  tsfm_hidden_size: 256
  tsfm_attention_heads: 8
  tsfm_num_blocks: 6
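The seven new causal_* keys mirror the constructor arguments added to NXCleanUNetConfig later in this commit. A minimal sketch, not part of the commit, assuming the example YAML is simply read with PyYAML (the project's own loading code is not shown in this diff):

import yaml

with open("examples/nx_clean_unet/yaml/config.yaml", "r", encoding="utf-8") as f:
    hyperparameters = yaml.safe_load(f)

# collect the seven keys introduced by this commit
causal_kwargs = {k: v for k, v in hyperparameters.items() if k.startswith("causal_")}
print(causal_kwargs)
# {'causal_in_channels': 64, 'causal_out_channels': 64, 'causal_kernel_size': 3,
#  'causal_bias': False, 'causal_separable': True, 'causal_f_stride': 1, 'causal_num_layers': 3}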
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py ADDED
@@ -0,0 +1,6 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+
+
+ if __name__ == '__main__':
+     pass
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py ADDED
@@ -0,0 +1,261 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import math
+ import os
+ from typing import List, Optional, Union, Iterable
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+
+ norm_layer_dict = {
+     "batch_norm_2d": torch.nn.BatchNorm2d
+ }
+
+
+ activation_layer_dict = {
+     "relu": torch.nn.ReLU,
+     "identity": torch.nn.Identity,
+     "sigmoid": torch.nn.Sigmoid,
+ }
+
+
+ class CausalConv2d(nn.Module):
+     def __init__(self,
+                  in_channels: int,
+                  out_channels: int,
+                  kernel_size: Union[int, Iterable[int]],
+                  f_stride: int = 1,
+                  dilation: int = 1,
+                  do_f_pad: bool = True,
+                  bias: bool = True,
+                  separable: bool = False,
+                  norm_layer: str = "batch_norm_2d",
+                  activation_layer: str = "relu",
+                  lookahead: int = 0
+                  ):
+         super(CausalConv2d, self).__init__()
+         kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
+
+         if do_f_pad:
+             f_pad = kernel_size[1] // 2 + dilation - 1
+         else:
+             f_pad = 0
+
+         self.causal_left_pad = kernel_size[0] - 1 - lookahead
+         self.causal_right_pad = lookahead
+         self.constant_pad = nn.ConstantPad2d(
+             padding=(0, 0, self.causal_left_pad, self.causal_right_pad),
+             value=0.0
+         )
+
+         groups = math.gcd(in_channels, out_channels) if separable else 1
+         self.conv1 = nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             padding=(0, f_pad),
+             stride=(1, f_stride),
+             dilation=(1, dilation),
+             groups=groups,
+             bias=bias,
+         )
+
+         self.conv2 = None
+         if not any([groups == 1, max(kernel_size) == 1]):
+             self.conv2 = nn.Conv2d(
+                 out_channels,
+                 out_channels,
+                 kernel_size=1,
+                 bias=False,
+             )
+
+         self.norm = None
+         if norm_layer is not None:
+             norm_layer = norm_layer_dict[norm_layer]
+             self.norm = norm_layer(out_channels)
+
+         self.activation = None
+         if activation_layer is not None:
+             activation_layer = activation_layer_dict[activation_layer]
+             self.activation = activation_layer()
+
+     def forward(self,
+                 inputs: torch.Tensor,
+                 causal_cache: torch.Tensor = None,
+                 ):
+
+         if causal_cache is None:
+             # inputs shape: [batch_size, 1, time_steps, hidden_size]
+             x = self.constant_pad.forward(inputs)
+         else:
+             # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size]
+             # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size]
+             x = torch.concat(tensors=[causal_cache, inputs], dim=2)
+             # x shape: [batch_size, 1, time_steps2, hidden_size]
+             # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
+
+         x = self.conv1.forward(x)
+         # inputs shape: [batch_size, 1, time_steps, hidden_size]
+
+         if self.conv2:
+             x = self.conv2.forward(x)
+
+         if self.norm:
+             x = self.norm(x)
+         if self.activation:
+             x = self.activation(x)
+
+         causal_cache = x[:, :, -self.causal_left_pad:, :]
+
+         # inputs shape: [batch_size, 1, time_steps, hidden_size]
+         return x, causal_cache
+
+
+ class CausalConv2dEncoder(nn.Module):
+     def __init__(self,
+                  in_channels: int,
+                  out_channels: int,
+                  kernel_size: Union[int, Iterable[int]],
+                  f_stride: int = 1,
+                  dilation: int = 1,
+                  do_f_pad: bool = True,
+                  bias: bool = True,
+                  separable: bool = False,
+                  norm_layer: str = "batch_norm_2d",
+                  activation_layer: str = "relu",
+                  lookahead: int = 0,
+                  num_layers: int = 5,
+                  ):
+         super(CausalConv2dEncoder, self).__init__()
+         self.num_layers = num_layers
+
+         self.total_causal_left_pad = 0
+         self.total_causal_right_pad = 0
+
+         self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[])
+         for i_layer in range(num_layers):
+             conv = CausalConv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=kernel_size,
+                 f_stride=f_stride,
+                 dilation=dilation,
+                 do_f_pad=do_f_pad,
+                 bias=bias,
+                 separable=separable,
+                 norm_layer=norm_layer,
+                 activation_layer=activation_layer,
+                 lookahead=lookahead,
+             )
+             self.causal_conv_list.append(conv)
+
+             self.total_causal_left_pad += conv.causal_left_pad
+             self.total_causal_right_pad += conv.causal_right_pad
+
+             in_channels = out_channels
+
+     def forward(self, inputs: torch.Tensor):
+         # inputs shape: [batch_size, 1, time_steps, hidden_size]
+
+         x = inputs
+         for layer in self.causal_conv_list:
+             x, _ = layer.forward(x)
+         return x
+
+     def forward_chunk(self,
+                       chunk: torch.Tensor,
+                       causal_cache: torch.Tensor = None,
+                       ):
+         # causal_cache shape: [self.num_layers, 1, causal_left_pad, hidden_size]
+
+         new_causal_cache_list = list()
+         for idx, causal_conv in enumerate(self.causal_conv_list):
+             chunk, new_causal_cache = causal_conv.forward(
+                 inputs=chunk, causal_cache=causal_cache[idx: idx+1] if causal_cache is not None else None
+             )
+             new_causal_cache_list.append(new_causal_cache)
+
+         new_causal_cache = torch.cat(new_causal_cache_list, dim=0)
+         return chunk, new_causal_cache
+
+     def forward_chunk_by_chunk(self, inputs: torch.Tensor):
+         # inputs shape: [batch_size, 1, time_steps, hidden_size]
+         # batch_size = 1
+
+         batch_size, channels, time_steps, hidden_size = inputs.shape
+
+         causal_cache = None
+
+         outputs = []
+         for idx in range(0, time_steps, 1):
+             begin = idx
+             end = begin + self.total_causal_right_pad + 1
+             chunk_xs = inputs[:, :, begin:end, :]
+
+             ys, attention_cache = self.forward_chunk(
+                 chunk=chunk_xs,
+                 causal_cache=causal_cache,
+             )
+             # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
+             ys = ys[:, :, :1, :]
+
+             # ys shape: [batch_size, chunk_size, hidden_size]
+             outputs.append(ys)
+
+         ys = torch.cat(outputs, 2)
+         return ys
+
+
+ def main2():
+     conv = CausalConv2d(
+         in_channels=1,
+         out_channels=64,
+         kernel_size=3,
+         bias=False,
+         separable=True,
+         f_stride=1,
+         lookahead=0,
+     )
+
+     spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
+     # spec shape: [batch_size, 1, time_steps, hidden_size]
+     cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32)
+
+     output, _ = conv.forward(spec)
+     print(output.shape)
+
+     output, _ = conv.forward(spec, cache)
+     print(output.shape)
+
+     return
+
+
+ def main():
+     causal = CausalConv2dEncoder(
+         in_channels=1,
+         out_channels=64,
+         kernel_size=3,
+         bias=False,
+         separable=True,
+         f_stride=1,
+         lookahead=0,
+         num_layers=3,
+     )
+
+     spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
+     # spec shape: [batch_size, 1, time_steps, hidden_size]
+
+     output = causal.forward(spec)
+     print(output.shape)
+
+     output = causal.forward_chunk_by_chunk(spec)
+     print(output.shape)
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
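A quick comparison, not part of the commit, contrasting CausalConv2dEncoder.forward with forward_chunk_by_chunk using the same settings as main() above. Note that, as committed, forward_chunk_by_chunk always calls forward_chunk with causal_cache=None, so each single-frame chunk is convolved against zero left padding rather than the preceding frames; the two outputs therefore generally differ even in eval mode:

import torch

from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder

encoder = CausalConv2dEncoder(
    in_channels=1, out_channels=64, kernel_size=3,
    bias=False, separable=True, f_stride=1, lookahead=0, num_layers=3,
)
encoder.eval()  # use BatchNorm running statistics instead of batch statistics

spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
with torch.no_grad():
    full = encoder.forward(spec)                    # [1, 64, 200, 64]
    chunked = encoder.forward_chunk_by_chunk(spec)  # [1, 64, 200, 64]

print(full.shape, chunked.shape)
print(torch.allclose(full, chunked))  # typically False: the chunked path sees no left context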
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py CHANGED
@@ -20,6 +20,15 @@ class NXCleanUNetConfig(PretrainedConfig):
                  down_sampling_kernel_size: int = 4,
                  down_sampling_stride: int = 2,

+                  causal_in_channels: int = 64,
+                  causal_out_channels: int = 64,
+                  causal_kernel_size: int = 3,
+                  causal_bias: bool = False,
+                  causal_separable: bool = True,
+                  causal_f_stride: int = 1,
+                  # causal_lookahead: int = 0,
+                  causal_num_layers: int = 3,
+
                  tsfm_hidden_size: int = 256,
                  tsfm_attention_heads: int = 4,
                  tsfm_num_blocks: int = 6,
@@ -56,6 +65,15 @@ class NXCleanUNetConfig(PretrainedConfig):
          self.down_sampling_kernel_size = down_sampling_kernel_size
          self.down_sampling_stride = down_sampling_stride

+         self.causal_in_channels = causal_in_channels
+         self.causal_out_channels = causal_out_channels
+         self.causal_kernel_size = causal_kernel_size
+         self.causal_bias = causal_bias
+         self.causal_separable = causal_separable
+         self.causal_f_stride = causal_f_stride
+         # self.causal_lookahead = causal_lookahead
+         self.causal_num_layers = causal_num_layers
+
          self.tsfm_hidden_size = tsfm_hidden_size
          self.tsfm_attention_heads = tsfm_attention_heads
          self.tsfm_num_blocks = tsfm_num_blocks
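A minimal sketch, not part of the commit, showing the new fields with the defaults added above; all other constructor arguments are left at their defaults here:

from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig

config = NXCleanUNetConfig(
    causal_in_channels=64,
    causal_out_channels=64,
    causal_kernel_size=3,
    causal_bias=False,
    causal_separable=True,
    causal_f_stride=1,
    causal_num_layers=3,
)
print(config.causal_kernel_size, config.causal_num_layers)   # 3 3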
toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav CHANGED
Binary files a/toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav and b/toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav differ
 
toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py CHANGED
@@ -62,6 +62,7 @@ class InferenceNXCleanUNet(object):
 
          with torch.no_grad():
              enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
+             # enhanced_audios = self.model.forward(noisy_audios)
              # enhanced_audio shape: [batch_size, n_samples]
              # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
 
@@ -70,16 +71,16 @@
          return enhanced_audio
 
  def main():
-     model_zip_file = project_path / "trained_models/nx-clean-unet-44-epoch.zip"
+     model_zip_file = project_path / "trained_models/nx-clean-unet-14-epoch.zip"
      infer_nx_clean_unet = InferenceNXCleanUNet(model_zip_file)
 
      sample_rate = 8000
-     noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
+     noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
      noisy_audio, _ = librosa.load(
          noisy_audio_file.as_posix(),
          sr=sample_rate,
      )
-     # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
+     noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
      noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
      noisy_audio = noisy_audio.unsqueeze(dim=0)
 
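The inference steps visible in this hunk, written out as a standalone sketch. The commit does not show how InferenceNXCleanUNet loads the model from the zip, so an NXCleanUNet built from the default config stands in for self.model here; the constructor call is an assumption and the weights are untrained:

import librosa
import torch

from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNet

model = NXCleanUNet(config=NXCleanUNetConfig())   # assumed constructor signature
model.eval()

sample_rate = 8000
noisy_audio, _ = librosa.load("noisy.wav", sr=sample_rate)                # hypothetical input file
noisy_audio = noisy_audio[int(7 * sample_rate):int(9 * sample_rate)]      # 2 s excerpt, as in main()
noisy_audios = torch.tensor(noisy_audio, dtype=torch.float32).unsqueeze(dim=0)

with torch.no_grad():
    enhanced_audios = model.forward_chunk_by_chunk(noisy_audios)          # streaming path
    # enhanced_audios = model.forward(noisy_audios)                       # offline path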
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -11,6 +11,7 @@ from torch.nn import functional as F
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
  from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
  from toolbox.torchaudio.models.nx_clean_unet.transformer.transformer import TransformerEncoder
+ from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder
 
 
  class DownSamplingBlock(nn.Module):
@@ -166,6 +167,16 @@ class NXCleanUNet(nn.Module):
              kernel_size=config.down_sampling_kernel_size,
              stride=config.down_sampling_stride,
          )
+         self.causal_encoder = CausalConv2dEncoder(
+             in_channels=config.causal_in_channels,
+             out_channels=config.causal_out_channels,
+             kernel_size=config.causal_kernel_size,
+             bias=config.causal_bias,
+             separable=config.causal_separable,
+             f_stride=config.causal_f_stride,
+             lookahead=0,
+             num_layers=config.causal_num_layers,
+         )
          self.transformer = TransformerEncoder(
              input_size=config.down_sampling_hidden_channels,
              hidden_size=config.tsfm_hidden_size,
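This hunk only constructs the sub-module; its wiring into the forward pass is not part of the diff. A small check, not part of the commit, that builds the same CausalConv2dEncoder from a default NXCleanUNetConfig and inspects it. With kernel_size=3, lookahead=0 and 3 layers, each layer pads 2 steps on the left and none on the right, so the block adds no algorithmic lookahead:

from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder

config = NXCleanUNetConfig()
causal_encoder = CausalConv2dEncoder(
    in_channels=config.causal_in_channels,
    out_channels=config.causal_out_channels,
    kernel_size=config.causal_kernel_size,
    bias=config.causal_bias,
    separable=config.causal_separable,
    f_stride=config.causal_f_stride,
    lookahead=0,
    num_layers=config.causal_num_layers,
)
print(causal_encoder.total_causal_left_pad, causal_encoder.total_causal_right_pad)   # 6 0
print(sum(p.numel() for p in causal_encoder.parameters()))                           # parameter count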
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py CHANGED
@@ -245,7 +245,7 @@ class RelativeMultiHeadSelfAttention(nn.Module):
 
 
  def main():
-     rel_attention = RelativeMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
+     rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
 
      x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
      xt, new_cache = rel_attention.forward(x, x, x)
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py CHANGED
@@ -5,6 +5,7 @@ from typing import Dict, Optional, Tuple, List, Union
 
  import torch
  import torch.nn as nn
+ from fontTools.subset import prune_post_subset
 
  from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
  from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
@@ -69,7 +70,7 @@ class TransformerEncoderLayer(nn.Module):
      def forward(
              self,
              x: torch.Tensor,
-             mask: torch.Tensor,
+             mask: torch.Tensor = None,
              attention_cache: torch.Tensor = None,
      ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          """
@@ -175,40 +176,31 @@
 
      def forward_chunk(self,
                        xs: torch.Tensor,
-                       offset: int,
-                       attention_mask: torch.Tensor = None,
+                       max_att_cache_length: int,
                        attention_cache: torch.Tensor = None,
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
          """
          Forward just one chunk.
          :param xs: torch.Tensor. chunk input, with shape (b=1, time, mel-dim),
              where `time == (chunk_size - 1) * subsample_rate + subsample.right_context + 1`
-         :param offset: int. current offset in encoder output timestamp.
-         :param attention_mask:
-         :param attention_cache: torch.Tensor. cache tensor for KEY & VALUE in
-             transformer/conformer attention, with shape
-             (elayers, head, cache_t1, d_k * 2), where
-             `head * d_k == hidden-dim` and
-             `cache_t1 == chunk_size * num_decoding_left_chunks`.
+         :param max_att_cache_length:
+         :param attention_cache: torch.Tensor.
          :return:
          """
          # xs shape: [batch_size, time_steps, input_size]
          xs = self.input_linear.forward(xs)
          # xs shape: [batch_size, time_steps, hidden_size]
 
-         xs, position_embedding = self.positional_encoding.forward(xs, offset=offset)
-         # xs shape: [batch_size, time_steps, hidden_size]
-         # position_embedding shape: [1, time_steps, hidden_size]
-
          r_att_cache = []
          for idx, encoder_layer in enumerate(self.encoder_layer_list):
              xs, new_att_cache = encoder_layer.forward(
-                 x=xs, mask=attention_mask,
-                 position_embedding=position_embedding,
-                 attention_cache=attention_cache[idx: idx+1],
+                 x=xs, attention_cache=attention_cache[idx: idx+1] if attention_cache is not None else None,
              )
-             r_att_cache.append(new_att_cache[:, :, self.chunk_size:, :])
-             # r_att_cache.append(new_att_cache)
+             if new_att_cache.size(2) > max_att_cache_length:
+                 begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
+                 end = self.num_right_chunks * self.chunk_size
+                 new_att_cache = new_att_cache[:, :, -begin:-end, :]
+             r_att_cache.append(new_att_cache)
 
          r_att_cache = torch.cat(r_att_cache, dim=0)
 
@@ -221,25 +213,28 @@
 
          batch_size, time_steps, _ = xs.shape
 
-         # [num_blocks, attention_heads, num_left_chunks, dim]
-         # attention_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
-         attention_cache: torch.Tensor = torch.zeros((6, 8, 128, 256), device=xs.device)
-         attention_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+         # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
+         max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
+         attention_cache = None
 
          outputs = []
-         for idx in range(0, time_steps - self.chunk_size + 1, self.chunk_size):
-             begin = idx * self.chunk_size
-             end = begin + self.chunk_size
+         for idx in range(0, time_steps - self.chunk_size, self.chunk_size):
+             begin = idx
+             end = begin + self.chunk_size * (self.num_right_chunks + 1)
              chunk_xs = xs[:, begin:end, :]
+             # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
 
              ys, attention_cache = self.forward_chunk(
-                 xs=chunk_xs, attention_mask=attention_mask,
-                 offset=0, attention_cache=attention_cache
+                 xs=chunk_xs,
+                 max_att_cache_length=max_att_cache_length,
+                 attention_cache=attention_cache,
              )
+             # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), hidden_size]
+             ys = ys[:, :self.chunk_size, :]
 
-             # xs shape: [batch_size, chunk_size, hidden_size]
+             # ys shape: [batch_size, chunk_size, hidden_size]
              ys = self.output_linear.forward(ys)
-             # xs shape: [batch_size, chunk_size, input_size]
+             # ys shape: [batch_size, chunk_size, input_size]
 
              outputs.append(ys)
 
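Worked numbers, not part of the commit, for the cache trimming and chunking in forward_chunk and forward_chunk_by_chunk above, using the values this commit writes into toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml (tsfm_chunk_size=1, tsfm_num_left_chunks=128, tsfm_num_right_chunks=4):

chunk_size = 1
num_left_chunks = 128
num_right_chunks = 4

# per-layer attention cache is trimmed once it grows past this many time steps
max_att_cache_length = (num_left_chunks + num_right_chunks) * chunk_size
print(max_att_cache_length)                                  # 132

# the slice new_att_cache[:, :, -begin:-end, :] keeps begin - end steps
begin = (num_left_chunks + num_right_chunks) * chunk_size    # 132
end = num_right_chunks * chunk_size                          # 4
print(begin - end)                                           # 128 steps of left context retained

# each chunk fed to forward_chunk() spans chunk_size * (num_right_chunks + 1) steps
print(chunk_size * (num_right_chunks + 1))                   # 5 = 1 output step + 4 lookahead steps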
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml CHANGED
@@ -5,29 +5,38 @@ segment_size: 16000
  n_fft: 512
  win_size: 200
  hop_size: 80
+ # With hop_size 80, each STFT frame advances 10 ms, so the down-sampling is chosen to give a similar time resolution.
 
  # 2**down_sampling_num_layers,
- # e.g. 2**5=32 means 32 samples become one time step after down-sampling,
- # i.e. one step is 32/sample_rate = 0.004 s.
- # Then tsfm_chunk_size=4 is 16 ms and tsfm_chunk_size=8 is 32 ms.
+ # e.g. 2**6=64 means 64 samples become one time step after down-sampling,
+ # i.e. one step is 64/sample_rate = 0.008 s.
+ # Then tsfm_chunk_size=2 is 16 ms and tsfm_chunk_size=4 is 32 ms.
  # Assuming 1 s of left context and 30 ms of lookahead:
- # tsfm_chunk_size=1, tsfm_num_left_chunks=256, tsfm_num_right_chunks=8
- # tsfm_chunk_size=4, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2
- # tsfm_chunk_size=8, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1
- down_sampling_num_layers: 5
+ # tsfm_chunk_size=1, tsfm_num_left_chunks=128, tsfm_num_right_chunks=4
+ # tsfm_chunk_size=2, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2
+ # tsfm_chunk_size=4, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1
+ down_sampling_num_layers: 6
  down_sampling_in_channels: 1
  down_sampling_hidden_channels: 64
  down_sampling_kernel_size: 4
  down_sampling_stride: 2
 
+ causal_in_channels: 64
+ causal_out_channels: 64
+ causal_kernel_size: 3
+ causal_bias: false
+ causal_separable: true
+ causal_f_stride: 1
+ causal_num_layers: 3
+
  tsfm_hidden_size: 256
  tsfm_attention_heads: 8
  tsfm_num_blocks: 6
  tsfm_dropout_rate: 0.1
  tsfm_max_length: 512
- tsfm_chunk_size: 4
- tsfm_num_left_chunks: 64
- tsfm_num_right_chunks: 2
+ tsfm_chunk_size: 1
+ tsfm_num_left_chunks: 128
+ tsfm_num_right_chunks: 4
 
  discriminator_dim: 32
  discriminator_in_channel: 2
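The arithmetic behind the comments above, written out (not part of the commit); the 8 kHz sample rate matches the one used in the inference script in this commit:

sample_rate = 8000
down_sampling_num_layers = 6
tsfm_chunk_size = 1
tsfm_num_left_chunks = 128
tsfm_num_right_chunks = 4

samples_per_step = 2 ** down_sampling_num_layers               # 64 samples per down-sampled step
step_seconds = samples_per_step / sample_rate                  # 0.008 s = 8 ms per step
left_context_seconds = tsfm_num_left_chunks * tsfm_chunk_size * step_seconds   # about 1 s of left context
lookahead_ms = tsfm_num_right_chunks * tsfm_chunk_size * step_seconds * 1000   # about 30 ms of lookahead

print(samples_per_step, step_seconds, left_context_seconds, lookahead_ms)
# 64 0.008 1.024 32.0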