update

Changed files:
- examples/dfnet/step_2_train_model.py (+4 -13)
- examples/dfnet/yaml/config.yaml (+1 -1)
- examples/dtln/yaml/config.yaml (+2 -2)
- main.py (+9 -3)
- toolbox/torchaudio/models/dfnet/inference_dfnet.py (+115 -0)
- toolbox/torchaudio/models/dfnet/yaml/config.yaml (+74 -0)
- toolbox/torchaudio/modules/conv_stft.py (+7 -14)
- toolbox/torchaudio/modules/utils/__init__.py (+6 -0)
- toolbox/torchaudio/modules/utils/ema.py (+12 -0)

examples/dfnet/step_2_train_model.py (CHANGED)

@@ -187,18 +187,12 @@ def main():
     if last_step_idx != -1:
         logger.info(f"resume from steps-{last_step_idx}.")
         model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
-        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
 
         logger.info(f"load state dict for model.")
         with open(model_pt.as_posix(), "rb") as f:
             state_dict = torch.load(f, map_location="cpu", weights_only=True)
         model.load_state_dict(state_dict, strict=True)
 
-        logger.info(f"load state dict for optimizer.")
-        with open(optimizer_pth.as_posix(), "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        optimizer.load_state_dict(state_dict)
-
     if config.lr_scheduler == "CosineAnnealingLR":
         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
             optimizer,
@@ -270,14 +264,14 @@ def main():
         clean_audios: torch.Tensor = clean_audios.to(device)
         noisy_audios: torch.Tensor = noisy_audios.to(device)
 
-        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+        est_spec, est_wav, est_mask, lsnr, erb_encoder_h = model.forward(noisy_audios)
 
         mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
         neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
         mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
         lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss +
+        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
         if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
             logger.info(f"find nan or inf in loss.")
             continue
@@ -341,14 +335,14 @@ def main():
         clean_audios: torch.Tensor = clean_audios.to(device)
         noisy_audios: torch.Tensor = noisy_audios.to(device)
 
-        est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
+        est_spec, est_wav, est_mask, lsnr, erb_encoder_h = model.forward(noisy_audios)
 
         mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
         neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
         mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
         lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss +
+        loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + 0.3 * lsnr_loss
         if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
             logger.info(f"find nan or inf in loss.")
             continue
@@ -410,9 +404,6 @@ def main():
                 model_to_delete: Path = model_list.pop(0)
                 shutil.rmtree(model_to_delete.as_posix())
 
-            # save optim
-            torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
-
             # save metric
             if best_metric is None:
                 best_epoch_idx = epoch_idx

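The resume path now restores only the model weights: optimizer state is neither saved (the "# save optim" block is gone) nor reloaded, so a resumed run restarts the optimizer from scratch. The loss now weights the LSNR term at 0.3 (the previous coefficient is cut off in this view). A minimal standalone sketch of the combination plus the existing NaN/Inf guard; the helper name and signature are illustrative, not part of the script:

import torch

def combine_losses(mr_stft_loss: torch.Tensor,
                   neg_si_snr_loss: torch.Tensor,
                   mask_loss: torch.Tensor,
                   lsnr_loss: torch.Tensor,
                   lsnr_weight: float = 0.3):
    # Weighted sum used identically in the train and eval loops above.
    loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss + lsnr_weight * lsnr_loss
    # Returning None lets the caller `continue` past a corrupted batch.
    if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
        return None
    return loss
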
examples/dfnet/yaml/config.yaml (CHANGED)

@@ -68,7 +68,7 @@ seed: 1234
 
 num_workers: 8
 batch_size: 64
-eval_steps:
+eval_steps: 10000
 
 # runtime
 use_post_filter: true

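eval_steps: 10000 replaces a previously empty value, so evaluation is now gated on optimizer steps. A sketch of how such a gate is typically checked; the helper is illustrative and assumes a 0-based step counter:

def should_evaluate(step_idx: int, eval_steps: int) -> bool:
    # True every `eval_steps` steps, e.g. at steps 9999, 19999, ... for eval_steps=10000.
    return eval_steps > 0 and (step_idx + 1) % eval_steps == 0
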
examples/dtln/yaml/config.yaml (CHANGED)

@@ -24,6 +24,6 @@ max_epochs: 100
 clip_grad_norm: 10.0
 seed: 1234
 
-batch_size:
+batch_size: 64
 num_workers: 4
-eval_steps:
+eval_steps: 15000

main.py (CHANGED)

@@ -62,10 +62,10 @@ def shell(cmd: str):
 
 
 denoise_engines = {
-    "
-    "infer_cls":
+    "dfnet-nx-dns3": {
+        "infer_cls": InferenceFRCRN,
         "kwargs": {
-            "pretrained_model_path_or_zip_file": (project_path / "trained_models/
+            "pretrained_model_path_or_zip_file": (project_path / "trained_models/dfnet-nx-dns3.zip").as_posix()
         }
     },
     "frcrn-dns3": {
@@ -74,6 +74,12 @@ denoise_engines = {
             "pretrained_model_path_or_zip_file": (project_path / "trained_models/frcrn-dns3.zip").as_posix()
         }
     },
+    "mpnet-nx-speech": {
+        "infer_cls": InferenceMPNet,
+        "kwargs": {
+            "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-nx-speech.zip").as_posix()
+        }
+    },
 }
 
 

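denoise_engines maps an engine key to an inference class plus the kwargs needed to construct it; note that, as committed, the new dfnet-nx-dns3 entry is instantiated through InferenceFRCRN. A minimal sketch of the dispatch this table supports (get_denoise_engine is an assumed helper, not necessarily how main.py consumes the table):

def get_denoise_engine(engine_key: str):
    # e.g. engine_key = "mpnet-nx-speech"
    entry = denoise_engines[engine_key]
    infer_cls = entry["infer_cls"]          # e.g. InferenceMPNet
    return infer_cls(**entry["kwargs"])     # constructed with its trained-model zip path
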
toolbox/torchaudio/models/dfnet/inference_dfnet.py (ADDED, 115 lines)

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import shutil
import tempfile, time
import zipfile

import librosa
import numpy as np
import torch
import torchaudio

torch.set_num_threads(1)

from project_settings import project_path
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNetPretrainedModel, MODEL_FILE

logger = logging.getLogger("toolbox")


class InferenceDfNet(object):
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(device)
        self.model.eval()

    def load_models(self, model_path: str):
        model_path = Path(model_path)
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem

        config = DfNetConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model = DfNetPretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model.to(self.device)
        model.eval()

        shutil.rmtree(model_path)
        return config, model

    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        noisy_audio = noisy_audio.unsqueeze(dim=0)

        # noisy_audio shape: [batch_size, num_samples]
        enhanced_audio = self.enhancement_by_tensor(noisy_audio)
        # enhanced_audio shape: [channels, num_samples]
        enhanced_audio = enhanced_audio[0]
        # enhanced_audio shape: [num_samples]
        return enhanced_audio.cpu().numpy()

    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")

        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)

        with torch.no_grad():
            est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios)

        # shape: [batch_size, num_samples]
        enhanced_audio = torch.unsqueeze(est_wav, dim=1)
        # shape: [batch_size, 1, num_samples]

        enhanced_audio = enhanced_audio[0]
        # shape: [channels, num_samples]
        return enhanced_audio


def main():
    model_zip_file = project_path / "trained_models/dfnet-nx-dns3.zip"
    infer_model = InferenceDfNet(model_zip_file)

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
    noisy_audio, sample_rate = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    noisy_audio = noisy_audio.unsqueeze(dim=0)

    begin = time.time()
    enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")

    filename = "enhanced_audio.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)

    return


if __name__ == "__main__":
    main()

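Besides the tensor path exercised in main(), InferenceDfNet.enhancement_by_ndarray accepts raw numpy audio in [-1, 1]. A short usage sketch; the input file path is a placeholder:

import librosa

# Placeholder file; any mono recording at the model's 8 kHz sample rate works.
noisy_audio, sample_rate = librosa.load("noisy.wav", sr=8000)
infer_model = InferenceDfNet("trained_models/dfnet-nx-dns3.zip")
enhanced_audio = infer_model.enhancement_by_ndarray(noisy_audio)  # ndarray in, ndarray out
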
toolbox/torchaudio/models/dfnet/yaml/config.yaml (ADDED, 74 lines)

model_name: "dfnet"

# spec
sample_rate: 8000
nfft: 512
win_size: 200
hop_size: 80

spec_bins: 256

# model
conv_channels: 64
conv_kernel_size_input:
- 3
- 3
conv_kernel_size_inner:
- 1
- 3
conv_lookahead: 0

convt_kernel_size_inner:
- 1
- 3

embedding_hidden_size: 256
encoder_combine_op: "concat"

encoder_emb_skip_op: "none"
encoder_emb_linear_groups: 16
encoder_emb_hidden_size: 256

encoder_linear_groups: 32

decoder_emb_num_layers: 3
decoder_emb_skip_op: "none"
decoder_emb_linear_groups: 16
decoder_emb_hidden_size: 256

df_decoder_hidden_size: 256
df_num_layers: 2
df_order: 5
df_bins: 96
df_gru_skip: "grouped_linear"
df_decoder_linear_groups: 16
df_pathway_kernel_size_t: 5
df_lookahead: 2

# lsnr
n_frame: 3
lsnr_max: 30
lsnr_min: -15
norm_tau: 1.

# data
min_snr_db: -10
max_snr_db: 20

# train
lr: 0.001
lr_scheduler: "CosineAnnealingLR"
lr_scheduler_kwargs:
  T_max: 250000
  eta_min: 0.0001

max_epochs: 100
clip_grad_norm: 10.0
seed: 1234

num_workers: 8
batch_size: 64
eval_steps: 10000

# runtime
use_post_filter: true

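A quick sanity check of the STFT geometry this config implies (pure arithmetic, no project imports; the remark on spec_bins is an inference, not documented in the file):

sample_rate, nfft, win_size, hop_size = 8000, 512, 200, 80

print(win_size / sample_rate * 1000)  # 25.0 ms analysis window
print(hop_size / sample_rate * 1000)  # 10.0 ms hop, i.e. 100 frames per second
print(nfft // 2 + 1)                  # 257 one-sided FFT bins; spec_bins: 256 suggests one bin is dropped
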
toolbox/torchaudio/modules/conv_stft.py (CHANGED)

@@ -141,6 +141,7 @@ class ConviSTFT(nn.Module):
         # waveform = waveform / coff
         return waveform
 
+    @torch.no_grad()
     def forward_chunk(self,
                       spec: torch.Tensor,
                       waveform_cache: torch.Tensor = None,
@@ -163,22 +164,14 @@ class ConviSTFT(nn.Module):
         overlap_size = self.win_size - self.hop_size
 
         if waveform_cache is not None:
-
-
-
-            new_waveform_cache = waveform_current[:, :, -self.hop_size:]
-        else:
-            waveform_output = waveform_current[:, :, :-self.hop_size]
-            new_waveform_cache = waveform_current[:, :, -self.hop_size:]
+            waveform_current[:, :, :overlap_size] += waveform_cache
+            waveform_output = waveform_current[:, :, :self.hop_size]
+            new_waveform_cache = waveform_current[:, :, self.hop_size:]
 
         if coff_cache is not None:
-
-
-
-            new_coff_cache = coff_current[:, :, -self.hop_size:]
-        else:
-            coff_output = coff_current[:, :, :-self.hop_size]
-            new_coff_cache = coff_current[:, :, -self.hop_size:]
+            coff_current[:, :, :overlap_size] += coff_cache
+            coff_output = coff_current[:, :, :self.hop_size]
+            new_coff_cache = coff_current[:, :, self.hop_size:]
 
         waveform_output = waveform_output / (coff_output + 1e-8)
         return waveform_output, new_waveform_cache, new_coff_cache

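The rewritten forward_chunk folds the carried cache into the head of the current window, emits the first hop_size samples (which are now fully overlapped), and carries the remaining win_size - hop_size samples forward; the normalization coefficients get the identical treatment before the 1e-8-guarded division. A self-contained sketch of that overlap-add step, assuming one synthesized window of length win_size per call:

import torch

def overlap_add_step(frame: torch.Tensor, cache: torch.Tensor, hop_size: int):
    # frame: [batch, 1, win_size] newly synthesized window
    # cache: [batch, 1, win_size - hop_size] tail carried from the previous call
    overlap_size = frame.shape[-1] - hop_size
    frame = frame.clone()
    frame[:, :, :overlap_size] += cache       # complete the overlapping region
    output = frame[:, :, :hop_size]           # only these samples are final
    new_cache = frame[:, :, hop_size:]        # the rest waits for the next window
    return output, new_cache

frame = torch.randn(1, 1, 200)                # win_size = 200, as in the dfnet config
cache = torch.zeros(1, 1, 120)                # overlap_size = 200 - 80
out, cache = overlap_add_step(frame, cache, hop_size=80)
print(out.shape, cache.shape)                 # torch.Size([1, 1, 80]) torch.Size([1, 1, 120])
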
toolbox/torchaudio/modules/utils/__init__.py (ADDED, 6 lines)

#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == "__main__":
    pass

toolbox/torchaudio/modules/utils/ema.py (ADDED, 12 lines)

#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch.nn as nn


class ExponentialMovingAverage(nn.Module):
    def __init__(self):
        super().__init__()


if __name__ == "__main__":
    pass

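ExponentialMovingAverage lands as an empty stub. For orientation only, a hypothetical fleshed-out parameter EMA of the usual shape; nothing below is committed code:

import torch
import torch.nn as nn

class ParameterEMA:
    # Hypothetical: tracks shadow = decay * shadow + (1 - decay) * param.
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: nn.Module):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)  # copy non-float buffers (e.g. step counters) verbatim
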