update
- examples/nx_clean_unet/yaml/config.yaml +8 -0
- toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py +6 -0
- toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py +261 -0
- toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py +18 -0
- toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav +0 -0
- toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py +4 -3
- toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py +11 -0
- toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py +1 -1
- toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py +25 -30
- toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml +19 -10
examples/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -12,6 +12,14 @@ down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2
 
+causal_in_channels: 64
+causal_out_channels: 64
+causal_kernel_size: 3
+causal_bias: false
+causal_separable: true
+causal_f_stride: 1
+causal_num_layers: 3
+
 tsfm_hidden_size: 256
 tsfm_attention_heads: 8
 tsfm_num_blocks: 6
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == '__main__':
+    pass
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py
ADDED
@@ -0,0 +1,261 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import math
+import os
+from typing import List, Optional, Union, Iterable
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+norm_layer_dict = {
+    "batch_norm_2d": torch.nn.BatchNorm2d
+}
+
+
+activation_layer_dict = {
+    "relu": torch.nn.ReLU,
+    "identity": torch.nn.Identity,
+    "sigmoid": torch.nn.Sigmoid,
+}
+
+
+class CausalConv2d(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Iterable[int]],
+                 f_stride: int = 1,
+                 dilation: int = 1,
+                 do_f_pad: bool = True,
+                 bias: bool = True,
+                 separable: bool = False,
+                 norm_layer: str = "batch_norm_2d",
+                 activation_layer: str = "relu",
+                 lookahead: int = 0
+                 ):
+        super(CausalConv2d, self).__init__()
+        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
+
+        if do_f_pad:
+            f_pad = kernel_size[1] // 2 + dilation - 1
+        else:
+            f_pad = 0
+
+        self.causal_left_pad = kernel_size[0] - 1 - lookahead
+        self.causal_right_pad = lookahead
+        self.constant_pad = nn.ConstantPad2d(
+            padding=(0, 0, self.causal_left_pad, self.causal_right_pad),
+            value=0.0
+        )
+
+        groups = math.gcd(in_channels, out_channels) if separable else 1
+        self.conv1 = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            padding=(0, f_pad),
+            stride=(1, f_stride),
+            dilation=(1, dilation),
+            groups=groups,
+            bias=bias,
+        )
+
+        self.conv2 = None
+        if not any([groups == 1, max(kernel_size) == 1]):
+            self.conv2 = nn.Conv2d(
+                out_channels,
+                out_channels,
+                kernel_size=1,
+                bias=False,
+            )
+
+        self.norm = None
+        if norm_layer is not None:
+            norm_layer = norm_layer_dict[norm_layer]
+            self.norm = norm_layer(out_channels)
+
+        self.activation = None
+        if activation_layer is not None:
+            activation_layer = activation_layer_dict[activation_layer]
+            self.activation = activation_layer()
+
+    def forward(self,
+                inputs: torch.Tensor,
+                causal_cache: torch.Tensor = None,
+                ):
+
+        if causal_cache is None:
+            # inputs shape: [batch_size, 1, time_steps, hidden_size]
+            x = self.constant_pad.forward(inputs)
+        else:
+            # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size]
+            # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size]
+            x = torch.concat(tensors=[causal_cache, inputs], dim=2)
+            # x shape: [batch_size, 1, time_steps2, hidden_size]
+            # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
+
+        x = self.conv1.forward(x)
+        # x shape: [batch_size, out_channels, time_steps, hidden_size]
+
+        if self.conv2:
+            x = self.conv2.forward(x)
+
+        if self.norm:
+            x = self.norm(x)
+        if self.activation:
+            x = self.activation(x)
+
+        causal_cache = x[:, :, -self.causal_left_pad:, :]
+
+        # x shape: [batch_size, out_channels, time_steps, hidden_size]
+        return x, causal_cache
+
+
+class CausalConv2dEncoder(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Iterable[int]],
+                 f_stride: int = 1,
+                 dilation: int = 1,
+                 do_f_pad: bool = True,
+                 bias: bool = True,
+                 separable: bool = False,
+                 norm_layer: str = "batch_norm_2d",
+                 activation_layer: str = "relu",
+                 lookahead: int = 0,
+                 num_layers: int = 5,
+                 ):
+        super(CausalConv2dEncoder, self).__init__()
+        self.num_layers = num_layers
+
+        self.total_causal_left_pad = 0
+        self.total_causal_right_pad = 0
+
+        self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[])
+        for i_layer in range(num_layers):
+            conv = CausalConv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                f_stride=f_stride,
+                dilation=dilation,
+                do_f_pad=do_f_pad,
+                bias=bias,
+                separable=separable,
+                norm_layer=norm_layer,
+                activation_layer=activation_layer,
+                lookahead=lookahead,
+            )
+            self.causal_conv_list.append(conv)
+
+            self.total_causal_left_pad += conv.causal_left_pad
+            self.total_causal_right_pad += conv.causal_right_pad
+
+            in_channels = out_channels
+
+    def forward(self, inputs: torch.Tensor):
+        # inputs shape: [batch_size, 1, time_steps, hidden_size]
+
+        x = inputs
+        for layer in self.causal_conv_list:
+            x, _ = layer.forward(x)
+        return x
+
+    def forward_chunk(self,
+                      chunk: torch.Tensor,
+                      causal_cache: torch.Tensor = None,
+                      ):
+        # causal_cache shape: [self.num_layers, 1, causal_left_pad, hidden_size]
+
+        new_causal_cache_list = list()
+        for idx, causal_conv in enumerate(self.causal_conv_list):
+            chunk, new_causal_cache = causal_conv.forward(
+                inputs=chunk, causal_cache=causal_cache[idx: idx+1] if causal_cache is not None else None
+            )
+            new_causal_cache_list.append(new_causal_cache)
+
+        new_causal_cache = torch.cat(new_causal_cache_list, dim=0)
+        return chunk, new_causal_cache
+
+    def forward_chunk_by_chunk(self, inputs: torch.Tensor):
+        # inputs shape: [batch_size, 1, time_steps, hidden_size]
+        # batch_size = 1
+
+        batch_size, channels, time_steps, hidden_size = inputs.shape
+
+        causal_cache = None
+
+        outputs = []
+        for idx in range(0, time_steps, 1):
+            begin = idx
+            end = begin + self.total_causal_right_pad + 1
+            chunk_xs = inputs[:, :, begin:end, :]
+
+            ys, causal_cache = self.forward_chunk(
+                chunk=chunk_xs,
+                causal_cache=causal_cache,
+            )
+            # ys shape: [batch_size, channels, self.total_causal_right_pad + 1, hidden_size]
+            ys = ys[:, :, :1, :]
+
+            # ys shape: [batch_size, channels, 1, hidden_size]
+            outputs.append(ys)
+
+        ys = torch.cat(outputs, 2)
+        return ys
+
+
+def main2():
+    conv = CausalConv2d(
+        in_channels=1,
+        out_channels=64,
+        kernel_size=3,
+        bias=False,
+        separable=True,
+        f_stride=1,
+        lookahead=0,
+    )
+
+    spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
+    # spec shape: [batch_size, 1, time_steps, hidden_size]
+    cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32)
+
+    output, _ = conv.forward(spec)
+    print(output.shape)
+
+    output, _ = conv.forward(spec, cache)
+    print(output.shape)
+
+    return
+
+
+def main():
+    causal = CausalConv2dEncoder(
+        in_channels=1,
+        out_channels=64,
+        kernel_size=3,
+        bias=False,
+        separable=True,
+        f_stride=1,
+        lookahead=0,
+        num_layers=3,
+    )
+
+    spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
+    # spec shape: [batch_size, 1, time_steps, hidden_size]
+
+    output = causal.forward(spec)
+    print(output.shape)
+
+    output = causal.forward_chunk_by_chunk(spec)
+    print(output.shape)
+
+    return
+
+
+if __name__ == '__main__':
+    main()
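The padding scheme above is what makes the convolution causal along the time axis: with lookahead = 0, each layer pads kernel_size - 1 frames of zeros on the left of the time dimension, so an output frame never depends on future input frames. A minimal standalone sketch (not part of the commit, plain torch only) illustrating that property:

import torch
import torch.nn as nn

kernel_size = 3
lookahead = 0
causal_left_pad = kernel_size - 1 - lookahead   # 2 frames of history, as in CausalConv2d

# ConstantPad2d(padding=(left, right, top, bottom)) pads the last dim (freq) and the
# second-to-last dim (time); only the time axis gets the causal padding here.
pad = nn.ConstantPad2d(padding=(0, 0, causal_left_pad, lookahead), value=0.0)
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(kernel_size, 1), bias=False)

x = torch.randn(1, 1, 10, 4)            # [batch, channel, time, freq]
y = conv(pad(x))                        # time length is preserved: 10 frames

# Perturbing a future frame must not change earlier outputs.
x2 = x.clone()
x2[:, :, 7, :] += 1.0
y2 = conv(pad(x2))
print(torch.allclose(y[:, :, :7, :], y2[:, :, :7, :]))   # True: frames 0..6 are unchanged

With three such layers (causal_num_layers: 3 in the config), the encoder accumulates total_causal_left_pad = 3 * 2 = 6 frames of history and zero lookahead, which is why forward_chunk_by_chunk can emit one frame per step.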
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py
CHANGED
@@ -20,6 +20,15 @@ class NXCleanUNetConfig(PretrainedConfig):
                  down_sampling_kernel_size: int = 4,
                  down_sampling_stride: int = 2,
 
+                 causal_in_channels: int = 64,
+                 causal_out_channels: int = 64,
+                 causal_kernel_size: int = 3,
+                 causal_bias: bool = False,
+                 causal_separable: bool = True,
+                 causal_f_stride: int = 1,
+                 # causal_lookahead: int = 0,
+                 causal_num_layers: int = 3,
+
                  tsfm_hidden_size: int = 256,
                  tsfm_attention_heads: int = 4,
                  tsfm_num_blocks: int = 6,
@@ -56,6 +65,15 @@ class NXCleanUNetConfig(PretrainedConfig):
         self.down_sampling_kernel_size = down_sampling_kernel_size
         self.down_sampling_stride = down_sampling_stride
 
+        self.causal_in_channels = causal_in_channels
+        self.causal_out_channels = causal_out_channels
+        self.causal_kernel_size = causal_kernel_size
+        self.causal_bias = causal_bias
+        self.causal_separable = causal_separable
+        self.causal_f_stride = causal_f_stride
+        # self.causal_lookahead = causal_lookahead
+        self.causal_num_layers = causal_num_layers
+
         self.tsfm_hidden_size = tsfm_hidden_size
         self.tsfm_attention_heads = tsfm_attention_heads
         self.tsfm_num_blocks = tsfm_num_blocks
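All of the new causal_* parameters have defaults that match the updated yaml files, so existing configs without these keys should still load. A quick sketch (assuming the toolbox package is importable; values mirror yaml/config.yaml) of how the new fields travel through the config object:

from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig

config = NXCleanUNetConfig(
    causal_in_channels=64,
    causal_out_channels=64,
    causal_kernel_size=3,
    causal_bias=False,
    causal_separable=True,
    causal_f_stride=1,
    causal_num_layers=3,
)
print(config.causal_kernel_size, config.causal_num_layers)   # 3 3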
toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav
CHANGED
Binary files a/toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav and b/toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav differ
toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py
CHANGED
@@ -62,6 +62,7 @@ class InferenceNXCleanUNet(object):
 
         with torch.no_grad():
             enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
+            # enhanced_audios = self.model.forward(noisy_audios)
         # enhanced_audio shape: [batch_size, n_samples]
         # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
 
@@ -70,16 +71,16 @@ class InferenceNXCleanUNet(object):
         return enhanced_audio
 
 def main():
-    model_zip_file = project_path / "trained_models/nx-clean-unet-
+    model_zip_file = project_path / "trained_models/nx-clean-unet-14-epoch.zip"
     infer_nx_clean_unet = InferenceNXCleanUNet(model_zip_file)
 
     sample_rate = 8000
-    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-
+    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
     noisy_audio, _ = librosa.load(
         noisy_audio_file.as_posix(),
         sr=sample_rate,
     )
-
+    noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
     noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
     noisy_audio = noisy_audio.unsqueeze(dim=0)
 
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
CHANGED
@@ -11,6 +11,7 @@ from torch.nn import functional as F
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
 from toolbox.torchaudio.models.nx_clean_unet.transformer.transformer import TransformerEncoder
+from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder
 
 
 class DownSamplingBlock(nn.Module):
@@ -166,6 +167,16 @@ class NXCleanUNet(nn.Module):
             kernel_size=config.down_sampling_kernel_size,
             stride=config.down_sampling_stride,
         )
+        self.causal_encoder = CausalConv2dEncoder(
+            in_channels=config.causal_in_channels,
+            out_channels=config.causal_out_channels,
+            kernel_size=config.causal_kernel_size,
+            bias=config.causal_bias,
+            separable=config.causal_separable,
+            f_stride=config.causal_f_stride,
+            lookahead=0,
+            num_layers=config.causal_num_layers,
+        )
         self.transformer = TransformerEncoder(
             input_size=config.down_sampling_hidden_channels,
             hidden_size=config.tsfm_hidden_size,
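The new CausalConv2dEncoder is constructed between the down-sampling stack and the TransformerEncoder, and its channel counts (64) match down_sampling_hidden_channels, which suggests it runs on the down-sampled feature map; the forward-pass wiring itself is not shown in this hunk. A standalone sketch of that configuration, with made-up time/frequency sizes (treat the 200 x 32 input as an assumption):

import torch
from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder

causal_encoder = CausalConv2dEncoder(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    bias=False,
    separable=True,
    f_stride=1,
    lookahead=0,
    num_layers=3,
)

features = torch.randn(1, 64, 200, 32)   # [batch, channels, time, freq] -- illustrative sizes only
out = causal_encoder.forward(features)
print(out.shape)                          # torch.Size([1, 64, 200, 32]); the shape is preserved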
toolbox/torchaudio/models/nx_clean_unet/transformer/attention.py
CHANGED
@@ -245,7 +245,7 @@ class RelativeMultiHeadSelfAttention(nn.Module):
 
 
 def main():
-    rel_attention =
+    rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
 
     x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
     xt, new_cache = rel_attention.forward(x, x, x)
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Dict, Optional, Tuple, List, Union
 
 import torch
 import torch.nn as nn
+from fontTools.subset import prune_post_subset
 
 from toolbox.torchaudio.models.nx_clean_unet.transformer.mask import subsequent_chunk_mask
 from toolbox.torchaudio.models.nx_clean_unet.transformer.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
@@ -69,7 +70,7 @@ class TransformerEncoderLayer(nn.Module):
     def forward(
             self,
             x: torch.Tensor,
-            mask: torch.Tensor,
+            mask: torch.Tensor = None,
             attention_cache: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
@@ -175,40 +176,31 @@ class TransformerEncoder(nn.Module):
 
     def forward_chunk(self,
                       xs: torch.Tensor,
-
-                      attention_mask: torch.Tensor = None,
+                      max_att_cache_length: int,
                       attention_cache: torch.Tensor = None,
                       ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
        Forward just one chunk.
        :param xs: torch.Tensor. chunk input, with shape (b=1, time, mel-dim),
            where `time == (chunk_size - 1) * subsample_rate + subsample.right_context + 1`
-        :param
-        :param
-        :param attention_cache: torch.Tensor. cache tensor for KEY & VALUE in
-            transformer/conformer attention, with shape
-            (elayers, head, cache_t1, d_k * 2), where
-            `head * d_k == hidden-dim` and
-            `cache_t1 == chunk_size * num_decoding_left_chunks`.
+        :param max_att_cache_length:
+        :param attention_cache: torch.Tensor.
        :return:
        """
        # xs shape: [batch_size, time_steps, input_size]
        xs = self.input_linear.forward(xs)
        # xs shape: [batch_size, time_steps, hidden_size]
 
-        xs, position_embedding = self.positional_encoding.forward(xs, offset=offset)
-        # xs shape: [batch_size, time_steps, hidden_size]
-        # position_embedding shape: [1, time_steps, hidden_size]
-
        r_att_cache = []
        for idx, encoder_layer in enumerate(self.encoder_layer_list):
            xs, new_att_cache = encoder_layer.forward(
-                x=xs,
-                position_embedding=position_embedding,
-                attention_cache=attention_cache[idx: idx+1],
+                x=xs, attention_cache=attention_cache[idx: idx+1] if attention_cache is not None else None,
            )
-
-
+            if new_att_cache.size(2) > max_att_cache_length:
+                begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
+                end = self.num_right_chunks * self.chunk_size
+                new_att_cache = new_att_cache[:, :, -begin:-end, :]
+            r_att_cache.append(new_att_cache)
 
        r_att_cache = torch.cat(r_att_cache, dim=0)
 
@@ -221,25 +213,28 @@ class TransformerEncoder(nn.Module):
 
        batch_size, time_steps, _ = xs.shape
 
-        # [num_blocks, attention_heads, num_left_chunks,
-
-        attention_cache
-        attention_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool, device=xs.device)
+        # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
+        max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
+        attention_cache = None
 
        outputs = []
-        for idx in range(0, time_steps - self.chunk_size
-            begin = idx
-            end = begin + self.chunk_size
+        for idx in range(0, time_steps - self.chunk_size, self.chunk_size):
+            begin = idx
+            end = begin + self.chunk_size * (self.num_right_chunks + 1)
            chunk_xs = xs[:, begin:end, :]
+            # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
 
            ys, attention_cache = self.forward_chunk(
-                xs=chunk_xs,
-
+                xs=chunk_xs,
+                max_att_cache_length=max_att_cache_length,
+                attention_cache=attention_cache,
            )
+            # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), hidden_size]
+            ys = ys[:, :self.chunk_size, :]
 
-            #
+            # ys shape: [batch_size, chunk_size, hidden_size]
            ys = self.output_linear.forward(ys)
-            #
+            # ys shape: [batch_size, chunk_size, input_size]
 
            outputs.append(ys)
 
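The effect of the new max_att_cache_length logic is that the per-layer attention cache can never grow past (num_left_chunks + num_right_chunks) * chunk_size frames; when it does, only the most recent left-context window is kept and the trailing right-context frames are dropped as well. A small standalone sketch of just that slice, using the chunk settings from yaml/config.yaml (the cache tensor itself is fabricated for illustration):

import torch

chunk_size = 1
num_left_chunks = 128
num_right_chunks = 4
max_att_cache_length = (num_left_chunks + num_right_chunks) * chunk_size   # 132 frames

# pretend cache for one layer: [num_blocks, heads, cached_frames, d_k * 2]
new_att_cache = torch.randn(1, 8, 200, 64)

if new_att_cache.size(2) > max_att_cache_length:
    begin = (num_left_chunks + num_right_chunks) * chunk_size
    end = num_right_chunks * chunk_size
    new_att_cache = new_att_cache[:, :, -begin:-end, :]

print(new_att_cache.shape)   # torch.Size([1, 8, 128, 64]) -- only the left-context frames remain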
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -5,29 +5,38 @@ segment_size: 16000
 n_fft: 512
 win_size: 200
 hop_size: 80
+# Since hop_size is 80, the STFT time step is 10 ms, so the down-sampling is chosen to give a similar resolution.
 
 # 2**down_sampling_num_layers,
-# e.g. 2**
-# then one step is
-# then tsfm_chunk_size=
+# e.g. 2**6 = 64 means that 64 samples become one time step after down-sampling,
+# so one step is 64/sample_rate = 0.008 seconds.
+# Then tsfm_chunk_size=2 corresponds to 16 ms and tsfm_chunk_size=4 to 32 ms.
 # Assuming each step looks 1 second to the left and 30 ms to the right:
-# tsfm_chunk_size=1, tsfm_num_left_chunks=
-# tsfm_chunk_size=
-# tsfm_chunk_size=
-down_sampling_num_layers:
+# tsfm_chunk_size=1, tsfm_num_left_chunks=128, tsfm_num_right_chunks=4
+# tsfm_chunk_size=2, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2
+# tsfm_chunk_size=4, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1
+down_sampling_num_layers: 6
 down_sampling_in_channels: 1
 down_sampling_hidden_channels: 64
 down_sampling_kernel_size: 4
 down_sampling_stride: 2
 
+causal_in_channels: 64
+causal_out_channels: 64
+causal_kernel_size: 3
+causal_bias: false
+causal_separable: true
+causal_f_stride: 1
+causal_num_layers: 3
+
 tsfm_hidden_size: 256
 tsfm_attention_heads: 8
 tsfm_num_blocks: 6
 tsfm_dropout_rate: 0.1
 tsfm_max_length: 512
-tsfm_chunk_size:
-tsfm_num_left_chunks:
-tsfm_num_right_chunks:
+tsfm_chunk_size: 1
+tsfm_num_left_chunks: 128
+tsfm_num_right_chunks: 4
 
 discriminator_dim: 32
 discriminator_in_channel: 2
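The comments above can be sanity-checked with a few lines of arithmetic (a sketch, not part of the commit; sample_rate = 8000 is taken from the inference script):

sample_rate = 8000
down_sampling_num_layers = 6

samples_per_step = 2 ** down_sampling_num_layers         # 64 samples per down-sampled time step
seconds_per_step = samples_per_step / sample_rate        # 0.008 s = 8 ms

tsfm_chunk_size = 1
tsfm_num_left_chunks = 128
tsfm_num_right_chunks = 4

left_context = tsfm_num_left_chunks * tsfm_chunk_size * seconds_per_step    # 1.024 s, i.e. roughly "1 second to the left"
right_context = tsfm_num_right_chunks * tsfm_chunk_size * seconds_per_step  # 0.032 s, i.e. roughly "30 ms to the right"
print(left_context, right_context)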