add microphone audio input

- Dockerfile +3 -0
- examples/dfnet/step_2_train_model.py +6 -0
- examples/dfnet/yaml/config.yaml +1 -1
- examples/dtln/step_2_train_model.py +4 -0
- examples/dtln/yaml/config.yaml +14 -8
- examples/frcrn/step_2_train_model.py +4 -0
- examples/{simple_lstm_irm → lstm}/run.sh +0 -0
- examples/{simple_lstm_irm → lstm}/step_1_prepare_data.py +0 -0
- examples/lstm/step_2_train_model.py +476 -0
- examples/{simple_lstm_irm → lstm}/step_3_evaluation.py +2 -2
- examples/mpnet/step_2_train_model.py +4 -0
- examples/simple_lstm_irm/step_2_train_model.py +0 -346
- toolbox/torchaudio/models/dfnet/configuration_dfnet.py +4 -0
- toolbox/torchaudio/models/dfnet/conv_stft.py +0 -148
- toolbox/torchaudio/models/dfnet/modeling_dfnet.py +86 -89
- toolbox/torchaudio/models/frcrn/conv_stft.py +2 -2
- toolbox/torchaudio/models/{simple_lstm_irm → lstm}/__init__.py +0 -0
- toolbox/torchaudio/models/lstm/configuration_lstm.py +73 -0
- toolbox/torchaudio/models/lstm/modeling_lstm.py +260 -0
- toolbox/torchaudio/models/lstm/yaml/config.yaml +32 -0
- toolbox/torchaudio/models/simple_lstm_irm/configuration_simple_lstm_irm.py +0 -38
- toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py +0 -133
- toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml +0 -14
- toolbox/torchaudio/modules/conv_stft.py +132 -13
- toolbox/torchaudio/modules/freq_bands/__init__.py +6 -0
- toolbox/torchaudio/modules/freq_bands/erb_bands.py +173 -0
Dockerfile
CHANGED
@@ -4,6 +4,9 @@ WORKDIR /code
 
 COPY . /code
 
+RUN apt-get update
+RUN apt-get install -y ffmpeg build-essential
+
 RUN pip install --upgrade pip
 RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
 
examples/dfnet/step_2_train_model.py
CHANGED
@@ -15,6 +15,8 @@ import sys
 import shutil
 from typing import List
 
+from fontTools.varLib.plot import stops
+
 pwd = os.path.abspath(os.path.dirname(__file__))
 sys.path.append(os.path.join(pwd, "../../"))
 
@@ -243,7 +245,11 @@ def main():
     step_idx = 0 if last_step_idx == -1 else last_step_idx
 
     logger.info("training")
+    early_stop_flag = False
     for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+        if early_stop_flag:
+            break
+
         # train
         model.train()
 
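The same epoch-level early-stop guard is repeated in the dtln, frcrn and mpnet trainers below. A minimal runnable sketch of the pattern these hunks add (the `evaluate` callable and the greater-is-better metric are illustrative stand-ins for each script's own evaluation logic):

from typing import Callable

def train_with_early_stop(max_epochs: int, patience: int, evaluate: Callable[[], float]):
    # Mirrors the hunks in this commit: the flag is set when the metric
    # stops improving, and checked at the top of the epoch loop.
    best_metric = None
    patience_count = 0
    early_stop_flag = False
    for epoch_idx in range(max_epochs):
        if early_stop_flag:
            break
        # ... one epoch of training would run here ...
        metric = evaluate()
        if best_metric is None or metric > best_metric:
            best_metric = metric
            patience_count = 0
        else:
            patience_count += 1
            if patience_count >= patience:
                early_stop_flag = True

train_with_early_stop(max_epochs=100, patience=10, evaluate=lambda: 0.0)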
examples/dfnet/yaml/config.yaml
CHANGED
@@ -68,7 +68,7 @@ seed: 1234
 
 num_workers: 8
 batch_size: 32
-eval_steps:
+eval_steps: 25000
 
 # runtime
 use_post_filter: true
examples/dtln/step_2_train_model.py
CHANGED
@@ -235,7 +235,11 @@ def main():
     step_idx = 0 if last_step_idx == -1 else last_step_idx
 
     logger.info("training")
+    early_stop_flag = False
     for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+        if early_stop_flag:
+            break
+
         # train
         model.train()
 
examples/dtln/yaml/config.yaml
CHANGED
@@ -1,23 +1,29 @@
 model_name: "DTLN"
 
+# spec
 sample_rate: 8000
 fft_size: 256
 hop_size: 128
 win_type: hann
 
+# data
 max_snr_db: 20
 min_snr_db: -10
 
+# model
 encoder_size: 256
 
-
-batch_size: 4
-num_workers: 4
-seed: 1234
-eval_steps: 25000
-
+# train
 lr: 0.001
-lr_scheduler: CosineAnnealingLR
-lr_scheduler_kwargs:
+lr_scheduler: "CosineAnnealingLR"
+lr_scheduler_kwargs:
+  T_max: 250000
+  eta_min: 0.0001
 
+max_epochs: 100
 clip_grad_norm: 10.0
+seed: 1234
+
+batch_size: 32
+num_workers: 4
+eval_steps: 25000
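The new `lr_scheduler` / `lr_scheduler_kwargs` keys are consumed the way the new LSTM trainer added in this commit consumes them: the kwargs mapping is splatted straight into the PyTorch scheduler constructor. A standalone sketch (the parameter list here is an illustrative stand-in for model.parameters(); only the scheduler wiring is taken from the diff):

import torch

# Values from examples/dtln/yaml/config.yaml after this commit.
lr = 0.001
lr_scheduler = "CosineAnnealingLR"
lr_scheduler_kwargs = {"T_max": 250000, "eta_min": 0.0001}

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model.parameters()
optimizer = torch.optim.AdamW(params, lr)

if lr_scheduler == "CosineAnnealingLR":
    # kwargs from the yaml are passed through unchanged, as in the new trainer.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **lr_scheduler_kwargs)
else:
    raise AssertionError(f"invalid lr_scheduler: {lr_scheduler}")

scheduler.step()  # the trainers above step once per batch
print(scheduler.get_last_lr()[0])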
examples/frcrn/step_2_train_model.py
CHANGED
@@ -238,7 +238,11 @@ def main():
     step_idx = 0 if last_step_idx == -1 else last_step_idx
 
     logger.info("training")
+    early_stop_flag = False
     for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+        if early_stop_flag:
+            break
+
         # train
         model.train()
 
examples/{simple_lstm_irm → lstm}/run.sh
RENAMED
File without changes

examples/{simple_lstm_irm → lstm}/step_1_prepare_data.py
RENAMED
File without changes
examples/lstm/step_2_train_model.py
ADDED
@@ -0,0 +1,476 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
+"""
+import argparse
+import json
+import logging
+from logging.handlers import TimedRotatingFileHandler
+import os
+import platform
+from pathlib import Path
+import random
+import sys
+import shutil
+from typing import List
+
+pwd = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(pwd, "../../"))
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data.dataloader import DataLoader
+import torchaudio
+from tqdm import tqdm
+
+from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
+from toolbox.torchaudio.metrics.pesq import run_pesq_score
+from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
+from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train_dataset", default="train.jsonl", type=str)
+    parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
+    parser.add_argument("--max_epochs", default=100, type=int)
+
+    parser.add_argument("--batch_size", default=64, type=int)
+    parser.add_argument("--learning_rate", default=1e-3, type=float)
+    parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
+    parser.add_argument("--patience", default=10, type=int)
+    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
+    parser.add_argument("--seed", default=0, type=int)
+
+    parser.add_argument("--config_file", default="config.yaml", type=str)
+
+    args = parser.parse_args()
+    return args
+
+
+def logging_config(file_dir: str):
+    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+    logging.basicConfig(format=fmt,
+                        datefmt="%m/%d/%Y %H:%M:%S",
+                        level=logging.INFO)
+    file_handler = TimedRotatingFileHandler(
+        filename=os.path.join(file_dir, "main.log"),
+        encoding="utf-8",
+        when="D",
+        interval=1,
+        backupCount=7
+    )
+    file_handler.setLevel(logging.INFO)
+    file_handler.setFormatter(logging.Formatter(fmt))
+    logger = logging.getLogger(__name__)
+    logger.addHandler(file_handler)
+
+    return logger
+
+
+class CollateFunction(object):
+    def __init__(self,
+                 n_fft: int = 512,
+                 win_length: int = 200,
+                 hop_length: int = 80,
+                 window_fn: str = "hamming",
+                 irm_beta: float = 1.0,
+                 epsilon: float = 1e-8,
+                 ):
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.window_fn = window_fn
+        self.irm_beta = irm_beta
+        self.epsilon = epsilon
+
+        self.stft_mag = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
+            power=1.0,
+            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
+        )
+        self.stft_complex = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
+            power=None,
+            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
+        )
+
+        self.istft = torchaudio.transforms.InverseSpectrogram(
+            n_fft=self.n_fft,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
+            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
+        )
+
+    def __call__(self, batch: List[dict]):
+        mag_noisy_audios = list()
+        pha_noisy_audios = list()
+        irm_gth = list()
+
+        clean_audios = list()
+
+        for sample in batch:
+            noise_audio: torch.Tensor = sample["noise_wave"]
+            clean_audio: torch.Tensor = sample["speech_wave"]
+            noisy_audio: torch.Tensor = sample["mix_wave"]
+            snr_db: float = sample["snr_db"]
+
+            mag_noise = self.stft_mag.forward(noise_audio)
+            mag_clean = self.stft_mag.forward(clean_audio)
+            stft_noisy = self.stft_complex.forward(noisy_audio)
+
+            irm_clean = mag_clean / (mag_noise + mag_clean + self.epsilon)
+            irm_clean = torch.pow(irm_clean, self.irm_beta)
+
+            real = torch.real(stft_noisy)
+            imag = torch.imag(stft_noisy)
+            mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
+            pha_noisy = torch.atan2(imag, real)
+
+            mag_noisy_audios.append(mag_noisy)
+            pha_noisy_audios.append(pha_noisy)
+            irm_gth.append(irm_clean)
+            clean_audios.append(clean_audio)
+
+        mag_noisy_audios = torch.stack(mag_noisy_audios)
+        pha_noisy_audios = torch.stack(pha_noisy_audios)
+        irm_gth = torch.stack(irm_gth)
+        clean_audios = torch.stack(clean_audios)
+
+        # assert
+        if torch.any(torch.isnan(mag_noisy_audios)):
+            raise AssertionError("nan in mag_noisy_audios Tensor")
+        if torch.any(torch.isnan(pha_noisy_audios)):
+            raise AssertionError("nan in pha_noisy_audios Tensor")
+        if torch.any(torch.isnan(irm_gth)):
+            raise AssertionError("nan in irm_gth Tensor")
+        if torch.any(torch.isnan(clean_audios)):
+            raise AssertionError("nan in clean_audios Tensor")
+
+        return mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios
+
+    def enhance(self, mag_noisy: torch.Tensor, pha_noisy: torch.Tensor, irm_speech: torch.Tensor):
+        mag_denoise = mag_noisy * irm_speech
+        stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
+        denoise = self.istft.forward(stft_denoise)
+        return denoise
+
+
+collate_fn = CollateFunction()
+
+
+def main():
+    args = get_args()
+
+    config = LstmConfig.from_pretrained(
+        pretrained_model_name_or_path=args.config_file,
+    )
+
+    serialization_dir = Path(args.serialization_dir)
+    serialization_dir.mkdir(parents=True, exist_ok=True)
+
+    logger = logging_config(serialization_dir)
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    logger.info("set seed: {}".format(args.seed))
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    n_gpu = torch.cuda.device_count()
+    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
+
+    # datasets
+    logger.info("prepare datasets")
+    train_dataset = DenoiseJsonlDataset(
+        jsonl_file=args.train_dataset,
+        expected_sample_rate=config.sample_rate,
+        max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+        # skip=225000,
+    )
+    valid_dataset = DenoiseJsonlDataset(
+        jsonl_file=args.valid_dataset,
+        expected_sample_rate=config.sample_rate,
+        max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+    )
+    train_data_loader = DataLoader(
+        dataset=train_dataset,
+        batch_size=config.batch_size,
+        # shuffle=True,
+        sampler=None,
+        # Multiple worker subprocesses can load data on Linux, but not on Windows.
+        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+        collate_fn=collate_fn,
+        pin_memory=False,
+        prefetch_factor=None if platform.system() == "Windows" else 2,
+    )
+    valid_data_loader = DataLoader(
+        dataset=valid_dataset,
+        batch_size=config.batch_size,
+        # shuffle=True,
+        sampler=None,
+        # Multiple worker subprocesses can load data on Linux, but not on Windows.
+        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+        collate_fn=collate_fn,
+        pin_memory=False,
+        prefetch_factor=None if platform.system() == "Windows" else 2,
+    )
+
+    # models
+    logger.info(f"prepare models. config_file: {args.config_file}")
+    model = LstmPretrainedModel(
+        config=config,
+    )
+    model.to(device)
+    model.train()
+
+    # optimizer
+    logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
+    optimizer = torch.optim.AdamW(model.parameters(), config.lr)
+
+    # resume training
+    last_step_idx = -1
+    last_epoch = -1
+    for step_idx_str in serialization_dir.glob("steps-*"):
+        step_idx_str = Path(step_idx_str)
+        step_idx = step_idx_str.stem.split("-")[1]
+        step_idx = int(step_idx)
+        if step_idx > last_step_idx:
+            last_step_idx = step_idx
+            # last_epoch = 1
+
+    if last_step_idx != -1:
+        logger.info(f"resume from steps-{last_step_idx}.")
+        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
+
+        logger.info(f"load state dict for model.")
+        with open(model_pt.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+
+        logger.info(f"load state dict for optimizer.")
+        with open(optimizer_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        optimizer.load_state_dict(state_dict)
+
+    if config.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            last_epoch=last_epoch,
+            # T_max=10 * config.eval_steps,
+            # eta_min=0.01 * config.lr,
+            **config.lr_scheduler_kwargs,
+        )
+    elif config.lr_scheduler == "MultiStepLR":
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            last_epoch=last_epoch,
+            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+        )
+    else:
+        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
+
+    mse_loss_fn = nn.MSELoss(
+        reduction="mean",
+    ).to(device)
+
+    # training loop
+    logger.info("training")
+
+    average_pesq_score = 1000000000
+    average_loss = 1000000000
+
+    model_list = list()
+    best_epoch_idx = None
+    best_step_idx = None
+    best_metric = None
+    patience_count = 0
+
+    step_idx = 0 if last_step_idx == -1 else last_step_idx
+
+    logger.info("training")
+    early_stop_flag = False
+    for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+        if early_stop_flag:
+            break
+
+        # train
+        model.train()
+
+        total_pesq_score = 0.
+        total_loss = 0.
+        total_batches = 0.
+
+        progress_bar_train = tqdm(
+            initial=step_idx,
+            desc="Training; epoch: {}".format(epoch_idx),
+        )
+        for train_batch in train_data_loader:
+            mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = train_batch
+            mag_noisy_audios = mag_noisy_audios.to(device)
+            pha_noisy_audios = pha_noisy_audios.to(device)
+            irm_gth = irm_gth.to(device)
+            clean_audios = clean_audios.to(device)
+
+            irm = model.forward(mag_noisy_audios)
+            denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm)
+            loss = mse_loss_fn.forward(irm, irm_gth)
+
+            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
+            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
+            optimizer.step()
+            lr_scheduler.step()
+
+            total_pesq_score += pesq_score
+            total_loss += loss.item()
+            total_batches += 1
+
+            average_pesq_score = round(total_pesq_score / total_batches, 4)
+            average_loss = round(total_loss / total_batches, 4)
+
+            progress_bar_train.update(1)
+            progress_bar_train.set_postfix({
+                "lr": lr_scheduler.get_last_lr()[0],
+                "pesq_score": average_pesq_score,
+                "loss": average_loss,
+            })
+
+            # evaluation
+            step_idx += 1
+            if step_idx % config.eval_steps == 0:
+                with torch.no_grad():
+                    torch.cuda.empty_cache()
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_train.close()
+                    progress_bar_eval = tqdm(
+                        desc="Evaluation; steps-{}k".format(int(step_idx / 1000)),
+                    )
+
+                    for eval_batch in valid_data_loader:
+                        mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = eval_batch
+                        mag_noisy_audios = mag_noisy_audios.to(device)
+                        pha_noisy_audios = pha_noisy_audios.to(device)
+                        irm_gth = irm_gth.to(device)
+                        clean_audios = clean_audios.to(device)
+
+                        with torch.no_grad():
+                            irm = model.forward(mag_noisy_audios)
+                            denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm)
+                            loss = mse_loss_fn.forward(irm, irm_gth)
+
+                            denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
+                            clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
+                            pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
+
+                        optimizer.zero_grad()
+                        loss.backward()
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
+                        optimizer.step()
+                        lr_scheduler.step()
+
+                        total_pesq_score += pesq_score
+                        total_loss += loss.item()
+                        total_batches += 1
+
+                        average_pesq_score = round(total_pesq_score / total_batches, 4)
+                        average_loss = round(total_loss / total_batches, 4)
+
+                        progress_bar_eval.update(1)
+                        progress_bar_eval.set_postfix({
+                            "lr": lr_scheduler.get_last_lr()[0],
+                            "pesq_score": average_pesq_score,
+                            "loss": average_loss,
+                        })
+
+                    total_pesq_score = 0.
+                    total_loss = 0.
+                    total_batches = 0.
+
+                    progress_bar_eval.close()
+                    progress_bar_train = tqdm(
+                        initial=progress_bar_train.n,
+                        postfix=progress_bar_train.postfix,
+                        desc=progress_bar_train.desc,
+                    )
+
+                    # save path
+                    epoch_dir = serialization_dir / "epoch-{}".format(epoch_idx)
+                    epoch_dir.mkdir(parents=True, exist_ok=False)
+
+                    # save models
+                    model.save_pretrained(epoch_dir.as_posix())
+
+                    model_list.append(epoch_dir)
+                    if len(model_list) >= args.num_serialized_models_to_keep:
+                        model_to_delete: Path = model_list.pop(0)
+                        shutil.rmtree(model_to_delete.as_posix())
+
+                    # save metric
+                    if best_metric is None:
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
+                        best_metric = average_pesq_score
+                    elif average_pesq_score >= best_metric:
+                        # greater is better.
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
+                        best_metric = average_pesq_score
+                    else:
+                        pass
+
+                    metrics = {
+                        "epoch_idx": epoch_idx,
+                        "best_epoch_idx": best_epoch_idx,
+                        "best_step_idx": best_step_idx,
+                        "pesq_score": average_pesq_score,
+                        "loss": average_loss,
+                    }
+                    metrics_filename = epoch_dir / "metrics_epoch.json"
+                    with open(metrics_filename, "w", encoding="utf-8") as f:
+                        json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+                    # save best
+                    best_dir = serialization_dir / "best"
+                    if best_epoch_idx == epoch_idx:
+                        if best_dir.exists():
+                            shutil.rmtree(best_dir)
+                        shutil.copytree(epoch_dir, best_dir)
+
+                    # early stop
+                    early_stop_flag = False
+                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
+                        patience_count = 0
+                    else:
+                        patience_count += 1
+                        if patience_count >= args.patience:
+                            early_stop_flag = True
+
+                    # early stop
+                    if early_stop_flag:
+                        break
+    return
+
+
+if __name__ == '__main__':
+    main()
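The heart of the new trainer is the collate function's IRM target and the `enhance` resynthesis. Below is a self-contained round trip using the same transform settings as `CollateFunction` (random tensors stand in for real 8 kHz audio; everything else mirrors the file above):

import torch
import torchaudio

n_fft, win_length, hop_length = 512, 200, 80
stft_mag = torchaudio.transforms.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length,
    power=1.0, window_fn=torch.hamming_window,
)
stft_complex = torchaudio.transforms.Spectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length,
    power=None, window_fn=torch.hamming_window,
)
istft = torchaudio.transforms.InverseSpectrogram(
    n_fft=n_fft, win_length=win_length, hop_length=hop_length,
    window_fn=torch.hamming_window,
)

clean = torch.randn(8000)  # stand-in for one second of clean speech
noise = torch.randn(8000)  # stand-in for noise
noisy = clean + noise

# IRM target, exactly as in CollateFunction.__call__ (irm_beta=1.0, epsilon=1e-8).
irm = stft_mag(clean) / (stft_mag(noise) + stft_mag(clean) + 1e-8)

# Resynthesis, as in CollateFunction.enhance: mask the noisy magnitude,
# reattach the noisy phase, and invert.
spec = stft_complex(noisy)
mag, pha = spec.abs(), spec.angle()
denoise = istft((mag * irm) * torch.exp(1j * pha))
print(denoise.shape)  # roughly the original number of samples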
examples/{simple_lstm_irm → lstm}/step_3_evaluation.py
RENAMED
@@ -19,7 +19,7 @@ import torch.nn as nn
 import torchaudio
 from tqdm import tqdm
 
-from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel
+from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel
 
 
 def get_args():
@@ -147,7 +147,7 @@ def main():
     logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
 
     logger.info("prepare model")
-    model = SimpleLstmIRMPretrainedModel.from_pretrained(
+    model = LstmPretrainedModel.from_pretrained(
         pretrained_model_name_or_path=args.model_dir,
     )
     model.to(device)
examples/mpnet/step_2_train_model.py
CHANGED
@@ -225,7 +225,11 @@ def main():
     patience_count = 0
 
     logger.info("training")
+    early_stop_flag = False
     for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
+        if early_stop_flag:
+            break
+
         # train
        generator.train()
        discriminator.train()
examples/simple_lstm_irm/step_2_train_model.py
DELETED
@@ -1,346 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
-"""
-import argparse
-import json
-import logging
-from logging.handlers import TimedRotatingFileHandler
-import os
-import platform
-from pathlib import Path
-import random
-import sys
-import shutil
-from typing import List
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.utils.data.dataloader import DataLoader
-import torchaudio
-from tqdm import tqdm
-
-from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
-from toolbox.torchaudio.models.simple_lstm_irm.configuration_simple_lstm_irm import SimpleLstmIRMConfig
-from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
-    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
-
-    parser.add_argument("--max_epochs", default=100, type=int)
-
-    parser.add_argument("--batch_size", default=64, type=int)
-    parser.add_argument("--learning_rate", default=1e-3, type=float)
-    parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
-    parser.add_argument("--patience", default=5, type=int)
-    parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
-    parser.add_argument("--seed", default=0, type=int)
-
-    parser.add_argument("--config_file", default="config.yaml", type=str)
-
-    args = parser.parse_args()
-    return args
-
-
-def logging_config(file_dir: str):
-    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
-
-    logging.basicConfig(format=fmt,
-                        datefmt="%m/%d/%Y %H:%M:%S",
-                        level=logging.INFO)
-    file_handler = TimedRotatingFileHandler(
-        filename=os.path.join(file_dir, "main.log"),
-        encoding="utf-8",
-        when="D",
-        interval=1,
-        backupCount=7
-    )
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(logging.Formatter(fmt))
-    logger = logging.getLogger(__name__)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-class CollateFunction(object):
-    def __init__(self,
-                 n_fft: int = 512,
-                 win_length: int = 200,
-                 hop_length: int = 80,
-                 window_fn: str = "hamming",
-                 irm_beta: float = 1.0,
-                 epsilon: float = 1e-8,
-                 ):
-        self.n_fft = n_fft
-        self.win_length = win_length
-        self.hop_length = hop_length
-        self.window_fn = window_fn
-        self.irm_beta = irm_beta
-        self.epsilon = epsilon
-
-        self.transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            win_length=self.win_length,
-            hop_length=self.hop_length,
-            power=2.0,
-            window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-        )
-
-    def __call__(self, batch: List[dict]):
-        mix_spec_list = list()
-        speech_irm_list = list()
-        snr_db_list = list()
-        for sample in batch:
-            noise_wave: torch.Tensor = sample["noise_wave"]
-            speech_wave: torch.Tensor = sample["speech_wave"]
-            mix_wave: torch.Tensor = sample["mix_wave"]
-            snr_db: float = sample["snr_db"]
-
-            noise_spec = self.transform.forward(noise_wave)
-            speech_spec = self.transform.forward(speech_wave)
-            mix_spec = self.transform.forward(mix_wave)
-
-            # noise_irm = noise_spec / (noise_spec + speech_spec)
-            speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
-            speech_irm = torch.pow(speech_irm, self.irm_beta)
-
-            mix_spec_list.append(mix_spec)
-            speech_irm_list.append(speech_irm)
-            snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32))
-
-        mix_spec_list = torch.stack(mix_spec_list)
-        speech_irm_list = torch.stack(speech_irm_list)
-        snr_db_list = torch.stack(snr_db_list)  # shape: (batch_size,)
-
-        # assert
-        if torch.any(torch.isnan(mix_spec_list)):
-            raise AssertionError("nan in mix_spec Tensor")
-        if torch.any(torch.isnan(speech_irm_list)):
-            raise AssertionError("nan in speech_irm Tensor")
-        if torch.any(torch.isnan(snr_db_list)):
-            raise AssertionError("nan in snr_db Tensor")
-
-        return mix_spec_list, speech_irm_list, snr_db_list
-
-
-collate_fn = CollateFunction()
-
-
-def main():
-    args = get_args()
-
-    serialization_dir = Path(args.serialization_dir)
-    serialization_dir.mkdir(parents=True, exist_ok=True)
-
-    logger = logging_config(serialization_dir)
-
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    logger.info("set seed: {}".format(args.seed))
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    n_gpu = torch.cuda.device_count()
-    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
-
-    # datasets
-    logger.info("prepare datasets")
-    train_dataset = DenoiseExcelDataset(
-        excel_file=args.train_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    valid_dataset = DenoiseExcelDataset(
-        excel_file=args.valid_dataset,
-        expected_sample_rate=8000,
-        max_wave_value=32768.0,
-    )
-    train_data_loader = DataLoader(
-        dataset=train_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # Multiple worker subprocesses can load data on Linux, but not on Windows.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-    valid_data_loader = DataLoader(
-        dataset=valid_dataset,
-        batch_size=args.batch_size,
-        shuffle=True,
-        # Multiple worker subprocesses can load data on Linux, but not on Windows.
-        num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
-        collate_fn=collate_fn,
-        pin_memory=False,
-        # prefetch_factor=64,
-    )
-
-    # models
-    logger.info(f"prepare models. config_file: {args.config_file}")
-    config = SimpleLstmIRMConfig.from_pretrained(
-        pretrained_model_name_or_path=args.config_file,
-        # num_labels=vocabulary.get_vocab_size(namespace="labels")
-    )
-    model = SimpleLstmIRMPretrainedModel(
-        config=config,
-    )
-    model.to(device)
-    model.train()
-
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    param_optimizer = model.parameters()
-    optimizer = torch.optim.Adam(
-        param_optimizer,
-        lr=args.learning_rate,
-    )
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=2000
-    # )
-    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
-        optimizer,
-        milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
-    )
-    mse_loss = nn.MSELoss(
-        reduction="mean",
-    )
-
-    # training loop
-    logger.info("training")
-
-    training_loss = 10000000000
-    evaluation_loss = 10000000000
-
-    model_list = list()
-    best_idx_epoch = None
-    best_metric = None
-    patience_count = 0
-
-    for idx_epoch in range(args.max_epochs):
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(train_data_loader),
-            desc="Training; epoch: {}".format(idx_epoch),
-        )
-
-        for batch in train_data_loader:
-            mix_spec, speech_irm, snr_db = batch
-            mix_spec = mix_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            speech_irm_prediction = model.forward(mix_spec)
-            loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-            total_loss += loss.item()
-            total_examples += mix_spec.size(0)
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-
-            training_loss = total_loss / total_examples
-            training_loss = round(training_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "training_loss": training_loss,
-            })
-
-        total_loss = 0.
-        total_examples = 0.
-        progress_bar = tqdm(
-            total=len(valid_data_loader),
-            desc="Evaluation; epoch: {}".format(idx_epoch),
-        )
-        for batch in valid_data_loader:
-            mix_spec, speech_irm, snr_db = batch
-            mix_spec = mix_spec.to(device)
-            speech_irm_target = speech_irm.to(device)
-            snr_db_target = snr_db.to(device)
-
-            with torch.no_grad():
-                speech_irm_prediction = model.forward(mix_spec)
-                loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-            total_loss += loss.item()
-            total_examples += mix_spec.size(0)
-
-            evaluation_loss = total_loss / total_examples
-            evaluation_loss = round(evaluation_loss, 4)
-
-            progress_bar.update(1)
-            progress_bar.set_postfix({
-                "evaluation_loss": evaluation_loss,
-            })
-
-        # save path
-        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
-        epoch_dir.mkdir(parents=True, exist_ok=False)
-
-        # save models
-        model.save_pretrained(epoch_dir.as_posix())
-
-        model_list.append(epoch_dir)
-        if len(model_list) >= args.num_serialized_models_to_keep:
-            model_to_delete: Path = model_list.pop(0)
-            shutil.rmtree(model_to_delete.as_posix())
-
-        # save metric
-        if best_metric is None:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        elif evaluation_loss < best_metric:
-            best_idx_epoch = idx_epoch
-            best_metric = evaluation_loss
-        else:
-            pass
-
-        metrics = {
-            "idx_epoch": idx_epoch,
-            "best_idx_epoch": best_idx_epoch,
-            "training_loss": training_loss,
-            "evaluation_loss": evaluation_loss,
-            "learning_rate": optimizer.param_groups[0]["lr"],
-        }
-        metrics_filename = epoch_dir / "metrics_epoch.json"
-        with open(metrics_filename, "w", encoding="utf-8") as f:
-            json.dump(metrics, f, indent=4, ensure_ascii=False)
-
-        # save best
-        best_dir = serialization_dir / "best"
-        if best_idx_epoch == idx_epoch:
-            if best_dir.exists():
-                shutil.rmtree(best_dir)
-            shutil.copytree(epoch_dir, best_dir)
-
-        # early stop
-        early_stop_flag = False
-        if best_idx_epoch == idx_epoch:
-            patience_count = 0
-        else:
-            patience_count += 1
-            if patience_count >= args.patience:
-                early_stop_flag = True
-
-        # early stop
-        if early_stop_flag:
-            break
-    return
-
-
-if __name__ == '__main__':
-    main()
toolbox/torchaudio/models/dfnet/configuration_dfnet.py
CHANGED
@@ -14,6 +14,8 @@ class DfNetConfig(PretrainedConfig):
                  win_type: str = "hann",
 
                  spec_bins: int = 256,
+                 erb_bins: int = 32,
+                 min_freq_bins_for_erb: int = 2,
 
                  conv_channels: int = 64,
                  conv_kernel_size_input: Tuple[int, int] = (3, 3),
@@ -79,6 +81,8 @@ class DfNetConfig(PretrainedConfig):
 
         # spectrum
         self.spec_bins = spec_bins
+        self.erb_bins = erb_bins
+        self.min_freq_bins_for_erb = min_freq_bins_for_erb
 
         # conv
         self.conv_channels = conv_channels
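The two new fields feed the `ErbBands` module added under toolbox/torchaudio/modules/freq_bands/erb_bands.py, whose internals are not shown in this diff. As a rough illustration of what grouping linear STFT bins into ERB bands involves (the edge-placement scheme below is an assumption for intuition, not the module's actual code; `min_freq_bins_for_erb` would enforce a minimum group width):

import math

def erb_scale(freq_hz: float) -> float:
    # Standard ERB-number approximation (Glasberg & Moore).
    return 9.265 * math.log(1.0 + freq_hz / (24.7 * 9.16))

def band_edges(sample_rate: int, spec_bins: int, erb_bins: int):
    # Split the linear STFT bins into erb_bins groups that are
    # equally wide on the ERB scale.
    nyquist = sample_rate / 2
    top = erb_scale(nyquist)
    edges = []
    for i in range(erb_bins + 1):
        f = (math.exp(i * top / erb_bins / 9.265) - 1.0) * 24.7 * 9.16
        edges.append(round(f / nyquist * (spec_bins - 1)))
    return edges

# Config defaults after this commit: spec_bins=256, erb_bins=32.
print(band_edges(sample_rate=8000, spec_bins=256, erb_bins=32))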
toolbox/torchaudio/models/dfnet/conv_stft.py
DELETED
@@ -1,148 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
-"""
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from scipy.signal import get_window
-
-
-def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
-    if win_type == "None" or win_type is None:
-        window = np.ones(win_size)
-    else:
-        window = get_window(win_type, win_size, fftbins=True)**0.5
-
-    fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
-    real_kernel = np.real(fourier_basis)
-    image_kernel = np.imag(fourier_basis)
-    kernel = np.concatenate([real_kernel, image_kernel], 1).T
-
-    if inverse:
-        kernel = np.linalg.pinv(kernel).T
-
-    kernel = kernel * window
-    kernel = kernel[:, None, :]
-    result = (
-        torch.from_numpy(kernel.astype(np.float32)),
-        torch.from_numpy(window[None, :, None].astype(np.float32))
-    )
-    return result
-
-
-class ConvSTFT(nn.Module):
-
-    def __init__(self,
-                 nfft: int,
-                 win_size: int,
-                 hop_size: int,
-                 win_type: str = "hamming",
-                 power: int = None,
-                 requires_grad: bool = False):
-        super(ConvSTFT, self).__init__()
-
-        if nfft is None:
-            self.nfft = int(2**np.ceil(np.log2(win_size)))
-        else:
-            self.nfft = nfft
-
-        kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
-        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
-
-        self.win_size = win_size
-        self.hop_size = hop_size
-
-        self.stride = hop_size
-        self.dim = self.nfft
-        self.power = power
-
-    def forward(self, inputs: torch.Tensor):
-        if inputs.dim() == 2:
-            inputs = torch.unsqueeze(inputs, 1)
-
-        matrix = F.conv1d(inputs, self.weight, stride=self.stride)
-        dim = self.dim // 2 + 1
-        real = matrix[:, :dim, :]
-        imag = matrix[:, dim:, :]
-        spec = torch.complex(real, imag)
-
-        if self.power is None:
-            return spec
-        elif self.power == 1:
-            mags = torch.sqrt(real**2 + imag**2)
-            # phase = torch.atan2(imag, real)
-            return mags
-        elif self.power == 2:
-            power = real**2 + imag**2
-            return power
-        else:
-            raise AssertionError
-
-
-class ConviSTFT(nn.Module):
-
-    def __init__(self,
-                 win_size: int,
-                 hop_size: int,
-                 nfft: int = None,
-                 win_type: str = "hamming",
-                 requires_grad: bool = False):
-        super(ConviSTFT, self).__init__()
-        if nfft is None:
-            self.nfft = int(2**np.ceil(np.log2(win_size)))
-        else:
-            self.nfft = nfft
-
-        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
-        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
-
-        self.win_size = win_size
-        self.hop_size = hop_size
-        self.win_type = win_type
-
-        self.stride = hop_size
-        self.dim = self.nfft
-
-        self.register_buffer("window", window)
-        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
-
-    def forward(self,
-                inputs: torch.Tensor):
-        """
-        :param inputs: torch.Tensor, shape: [b, f, t]
-        :return:
-        """
-        inputs = torch.view_as_real(inputs)
-        matrix = torch.concat(tensors=[inputs[..., 0], inputs[..., 1]], dim=1)
-
-        waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
-
-        # this is from torch-stft: https://github.com/pseeth/torch-stft
-        t = self.window.repeat(1, 1, matrix.size(-1))**2
-        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
-        waveform = waveform / (coff + 1e-8)
-        return waveform
-
-
-def main():
-    stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
-    istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)
-
-    mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
-
-    spec = stft.forward(mixture)
-    # shape: [batch_size, freq_bins, time_steps]
-    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
-
-    waveform = istft.forward(spec)
-    # shape: [batch_size, channels, num_samples]
-    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
-
-    return
-
-
-if __name__ == "__main__":
-    main()
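The deleted module is superseded by the shared toolbox/torchaudio/modules/conv_stft.py (extended by +132 -13 in this commit). For intuition, the conv-based analysis above is a strided-convolution reformulation of the STFT; a quick shape-level comparison against torch.stft (the sqrt window and no-centering choices mirror init_kernels; exact values may still differ by normalization convention):

import torch

x = torch.rand(1, 8000)
nfft, win_size, hop_size = 512, 512, 200

window = torch.hann_window(win_size, periodic=True).sqrt()
spec = torch.stft(
    x, n_fft=nfft, hop_length=hop_size, win_length=win_size,
    window=window, center=False, return_complex=True,
)
# Same [batch, freq_bins, time_steps] layout that ConvSTFT produces.
print(spec.shape, spec.dtype)  # torch.Size([1, 257, 38]) torch.complex64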
toolbox/torchaudio/models/dfnet/modeling_dfnet.py
CHANGED
@@ -12,8 +12,9 @@ import torchaudio
 
 from toolbox.torchaudio.configuration_utils import CONFIG_FILE
 from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
-from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT
+from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
 from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
+from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
 
 
 MODEL_FILE = "model.pt"
@@ -225,7 +226,8 @@ class GroupedLinear(nn.Module):
         # The better way, but not supported by torchscript
         # x = x.unflatten(-1, (self.groups, self.ws))  # [..., G, I/G]
         x = torch.einsum("btgi,gih->btgh", x, self.weight)  # [..., G, H/G]
-        x = x.flatten(2, 3)
+        x = x.flatten(2, 3)
+        # x: [b, t, h]
         return x
 
     def __repr__(self):
@@ -302,7 +304,8 @@ class SqueezedGRU_S(nn.Module):
         self.linear_out = nn.Identity()
 
     def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
-        x = self.linear_in.forward(inputs)
+        # inputs: shape: [b, t, h]
+        x = self.linear_in.forward(inputs)
 
         x, h = self.gru.forward(x, h)
 
@@ -327,8 +330,8 @@ class Concat(nn.Module):
 class Encoder(nn.Module):
     def __init__(self, config: DfNetConfig):
         super(Encoder, self).__init__()
-        self.embedding_input_size = config.conv_channels * config.
-        self.embedding_output_size = config.conv_channels * config.
+        self.embedding_input_size = config.conv_channels * config.erb_bins // 4
+        self.embedding_output_size = config.conv_channels * config.erb_bins // 4
         self.embedding_hidden_size = config.embedding_hidden_size
 
         self.spec_conv0 = CausalConv2d(
@@ -423,49 +426,55 @@ class Encoder(nn.Module):
         self.lsnr_offset = config.min_local_snr
 
     def forward(self,
-
+                feat_erb: torch.Tensor,
                 feat_spec: torch.Tensor,
                 hidden_state: torch.Tensor = None,
                 ):
-        #
-        e0 = self.spec_conv0.forward(
+        # feat_erb shape: (b, 1, t, erb_bins)
+        e0 = self.spec_conv0.forward(feat_erb)
         e1 = self.spec_conv1.forward(e0)
         e2 = self.spec_conv2.forward(e1)
         e3 = self.spec_conv3.forward(e2)
-        # e0 shape: [
-        # e1 shape: [
-        # e2 shape: [
-        # e3 shape: [
+        # e0 shape: [b, c, t, erb_bins]
+        # e1 shape: [b, c, t, erb_bins // 2]
+        # e2 shape: [b, c, t, erb_bins // 4]
+        # e3 shape: [b, c, t, erb_bins // 4]
+        # e3 shape: [b, 64, t, 32/4=8]
 
-        # feat_spec, shape: (
+        # feat_spec, shape: (b, 2, t, df_bins)
        c0 = self.df_conv0(feat_spec)
        c1 = self.df_conv1(c0)
-        # c0 shape: [
-        # c1 shape: [
+        # c0 shape: [b, c, t, df_bins]
+        # c1 shape: [b, c, t, df_bins // 2]
+        # c1 shape: [b, 64, t, 96/2=48]
 
         cemb = c1.permute(0, 2, 3, 1)
-        # cemb shape: [
+        # cemb shape: [b, t, df_bins // 2, c]
         cemb = cemb.flatten(2)
-        # cemb shape: [
-        cemb =
-
+        # cemb shape: [b, t, df_bins // 2 * c]
+        # cemb shape: [b, t, 96/2*64=3072]
+        cemb = self.df_fc_emb.forward(cemb)
+        # cemb shape: [b, t, erb_bins // 4 * c]
+        # cemb shape: [b, t, 32/4*64=512]
 
-        # e3 shape: [
+        # e3 shape: [b, c, t, erb_bins // 4]
         emb = e3.permute(0, 2, 3, 1)
-        # emb shape: [
+        # emb shape: [b, t, erb_bins // 4, c]
         emb = emb.flatten(2)
-        # emb shape: [
+        # emb shape: [b, t, erb_bins // 4 * c]
+        # emb shape: [b, t, 32/4*64=512]
 
         emb = self.combine(emb, cemb)
-        # if concat; emb shape: [
-        # if add; emb shape: [
+        # if concat; emb shape: [b, t, spec_bins // 4 * c * 2]
+        # if add; emb shape: [b, t, spec_bins // 4 * c]
 
         emb, h = self.emb_gru.forward(emb, hidden_state)
-
-        #
+
+        # emb shape: [b, t, spec_dim // 4 * c]
+        # h shape: [b, 1, spec_dim]
 
         lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
-        # lsnr shape: [
+        # lsnr shape: [b, t, 1]
 
         return e0, e1, e2, e3, emb, c0, lsnr, h
 
@@ -477,8 +486,8 @@ class Decoder(nn.Module):
         if config.spec_bins % 8 != 0:
             raise AssertionError("spec_bins should be divisible by 8")
 
-        self.emb_in_dim = config.conv_channels * config.
-        self.emb_out_dim = config.conv_channels * config.
+        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
+        self.emb_out_dim = config.conv_channels * config.erb_bins // 4
         self.emb_hidden_dim = config.decoder_emb_hidden_size
 
         self.emb_gru = SqueezedGRU_S(
@@ -570,7 +579,7 @@ class Decoder(nn.Module):
         b, _, t, f8 = e3.shape
 
         # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
-        emb, _ = self.emb_gru(emb)
+        emb, _ = self.emb_gru.forward(emb)
         # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
         emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
         e3 = self.convt3(self.conv3p(e3) + emb)
@@ -588,7 +597,7 @@ class DfDecoder(nn.Module):
     def __init__(self, config: DfNetConfig):
         super(DfDecoder, self).__init__()
 
-        self.embedding_input_size = config.conv_channels * config.
+        self.embedding_input_size = config.conv_channels * config.erb_bins // 4
         self.df_decoder_hidden_size = config.df_decoder_hidden_size
         self.df_num_layers = config.df_num_layers
 
@@ -712,14 +721,14 @@ class Mask(nn.Module):
         return mask_pf
 
     def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-        # spec shape: [
+        # spec shape: [b, 1, t, spec_bins, 2]
 
         if not self.training and self.use_post_filter:
             mask = self.post_filter(mask)
 
-        # mask shape: [
+        # mask shape: [b, 1, t, spec_bins]
         mask = mask.unsqueeze(4)
-        # mask shape: [
+        # mask shape: [b, 1, t, spec_bins, 1]
         return spec * mask
 
 
@@ -803,6 +812,13 @@ class DfNet(nn.Module):
         self.hop_size = config.hop_size
         self.win_type = config.win_type
 
         self.stft = ConvSTFT(
             nfft=config.nfft,
             win_size=config.win_size,
@@ -867,37 +883,42 @@ class DfNet(nn.Module):
         noisy, n_samples = self.signal_prepare(noisy)
 
         # noisy shape: [b, num_samples_pad]
-
-        #
-
-        #
-
-
-        #
-
-        spec = torch.unsqueeze(
-        # spec shape: [b,
-        spec = spec.permute(0, 4, 3, 2, 1)
-        # spec shape: [b, 1, t, spec_bins, 2]
 
-
-        #
 
-        feat_spec =
-        # feat_spec shape: [b, 2, t,
         feat_spec = feat_spec[..., :self.df_decoder.df_bins]
         # feat_spec shape: [b, 2, t, df_bins]
 
-        e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(
 
         mask = self.decoder.forward(emb, e3, e2, e1, e0)
-        # mask shape: [b, 1, t,
         if torch.any(mask > 1) or torch.any(mask < 0):
             raise AssertionError
 
         spec_m = self.mask.forward(spec, mask)
 
         # lsnr shape: [b, t, 1]
         lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
@@ -907,8 +928,10 @@ class DfNet(nn.Module):
         df_coefs = self.df_out_transform(df_coefs)
         # df_coefs shape: [b, df_order, t, df_bins, 2]
 
-
-        #
 
         spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
 
@@ -916,14 +939,10 @@ class DfNet(nn.Module):
         spec_e = spec_e.permute(0, 2, 1, 3)
         # spec_e shape: [b, spec_bins, t, 2]
 
-        mask = torch.squeeze(mask, dim=1)
-        mask = mask.permute(0, 2, 1)
-        # mask shape: [b, spec_bins, t]
-        est_mask = self.mask_transfer(mask)
-        # est_mask shape: [b, f, t]
-
         # spec_e shape: [b, spec_bins, t, 2]
-        est_spec =
         # est_spec shape: [b, f, t], torch.complex64
 
         est_wav = self.istft.forward(est_spec)
@@ -931,33 +950,11 @@ class DfNet(nn.Module):
         est_wav = est_wav[:, :n_samples]
         # est_wav shape: [b, n_samples]
 
-
 
-
-        # spec_e shape: [b, spec_bins, t, 2]
-        b, _, t, _ = spec_e.shape
-        est_spec = torch.complex(
-            real=torch.concat(tensors=[
-                spec_e[..., 0],
-                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
-            ], dim=1),
-            imag=torch.concat(tensors=[
-                spec_e[..., 1],
-                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
-            ], dim=1),
-        )
-        # est_spec shape: [b, f, t]
-        return est_spec
-
-    def mask_transfer(self, mask: torch.Tensor) -> torch.Tensor:
-        # mask shape: [b, 256, t]
-        b, _, t = mask.shape
-        est_mask = torch.concat(tensors=[
-            mask,
-            torch.zeros(size=(b, 1, t), dtype=mask.dtype).to(mask.device)
-        ], dim=1)
-        # est_mask shape: [b, 257, t]
-        return est_mask
 
     def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
         """
self.erb_bands = ErbBands(
|
816 |
+
sample_rate=config.sample_rate,
|
817 |
+
nfft=config.nfft,
|
818 |
+
erb_bins=config.erb_bins,
|
819 |
+
min_freq_bins_for_erb=config.min_freq_bins_for_erb,
|
820 |
+
)
|
821 |
+
|
822 |
self.stft = ConvSTFT(
|
823 |
nfft=config.nfft,
|
824 |
win_size=config.win_size,
|
|
|
883 |
noisy, n_samples = self.signal_prepare(noisy)
|
884 |
|
885 |
# noisy shape: [b, num_samples_pad]
|
886 |
+
spec_cmp = self.stft.forward(noisy)
|
887 |
+
# spec_complex shape: [b, f, t], torch.complex64
|
888 |
+
spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
|
889 |
+
# spec_complex shape: [b, t, f], torch.complex64
|
890 |
+
spec_cmp_real = torch.view_as_real(spec_cmp)
|
891 |
+
# spec_cmp_real shape: [b, t, f, 2]
|
892 |
+
spec_mag = torch.abs(spec_cmp)
|
893 |
+
spec_pow = torch.square(spec_mag)
|
894 |
+
# shape: [b, t, f]
|
895 |
+
|
896 |
+
spec = torch.unsqueeze(spec_cmp_real, dim=1)
|
897 |
+
# spec shape: [b, 1, t, f, 2]
|
|
|
|
|
898 |
|
899 |
+
feat_erb = self.erb_bands.erb_scale(spec_pow, db=True)
|
900 |
+
# feat_erb shape: [b, t, erb_bins]
|
901 |
+
feat_erb = torch.unsqueeze(feat_erb, dim=1)
|
902 |
+
# feat_erb shape: [b, 1, t, erb_bins]
|
903 |
|
904 |
+
feat_spec = spec_cmp_real.permute(0, 3, 1, 2)
|
905 |
+
# feat_spec shape: [b, 2, t, f]
|
906 |
feat_spec = feat_spec[..., :self.df_decoder.df_bins]
|
907 |
# feat_spec shape: [b, 2, t, df_bins]
|
908 |
|
909 |
+
e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_erb, feat_spec)
|
910 |
|
911 |
mask = self.decoder.forward(emb, e3, e2, e1, e0)
|
912 |
+
# mask shape: [b, 1, t, erb_bins]
|
913 |
+
mask = self.erb_bands.erb_scale_inv(mask)
|
914 |
+
# mask shape: [b, 1, t, f]
|
915 |
if torch.any(mask > 1) or torch.any(mask < 0):
|
916 |
raise AssertionError
|
917 |
|
918 |
spec_m = self.mask.forward(spec, mask)
|
919 |
+
# spec_m shape: [b, 1, t, f, 2]
|
920 |
+
spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
|
921 |
+
# spec_m shape: [b, 1, t, spec_bins, 2]
|
922 |
|
923 |
# lsnr shape: [b, t, 1]
|
924 |
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
|
|
|
928 |
df_coefs = self.df_out_transform(df_coefs)
|
929 |
# df_coefs shape: [b, df_order, t, df_bins, 2]
|
930 |
|
931 |
+
spec_ = spec[:, :, :, :self.config.spec_bins, :]
|
932 |
+
# spec shape: [b, 1, t, spec_bins, 2]
|
933 |
+
spec_e = self.df_op.forward(spec_, df_coefs)
|
934 |
+
# spec_e shape: [b, 1, t, spec_bins, 2]
|
935 |
|
936 |
spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
|
937 |
|
|
|
939 |
spec_e = spec_e.permute(0, 2, 1, 3)
|
940 |
# spec_e shape: [b, spec_bins, t, 2]
|
941 |
|
|
|
|
|
|
|
|
|
|
|
|
|
942 |
# spec_e shape: [b, spec_bins, t, 2]
|
943 |
+
est_spec = torch.complex(real=spec_e[..., 0], imag=spec_e[..., 1])
|
944 |
+
# est_spec shape: [b, spec_bins, t], torch.complex64
|
945 |
+
est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
|
946 |
# est_spec shape: [b, f, t], torch.complex64
|
947 |
|
948 |
est_wav = self.istft.forward(est_spec)
|
|
|
950 |
est_wav = est_wav[:, :n_samples]
|
951 |
# est_wav shape: [b, n_samples]
|
952 |
|
953 |
+
est_mask = torch.squeeze(mask, dim=1)
|
954 |
+
est_mask = est_mask.permute(0, 2, 1)
|
955 |
+
# est_mask shape: [b, f, t]
|
956 |
|
957 |
+
return est_spec, est_wav, est_mask, lsnr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
958 |
|
959 |
def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
|
960 |
"""
|
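Not part of the commit, a minimal sketch of the new output assembly above: the deep-filter branch works on spec_bins = nfft // 2 bins while ConviSTFT expects f = nfft // 2 + 1, so the rewritten forward packs real and imaginary parts into a complex tensor and repeats the top bin as the Nyquist bin (the deleted helpers zero-padded it instead).

import torch

b, spec_bins, t = 2, 256, 199
spec_e = torch.randn(size=(b, spec_bins, t, 2), dtype=torch.float32)

# pack real/imag into a complex spectrum, then repeat the highest retained bin
est_spec = torch.complex(real=spec_e[..., 0], imag=spec_e[..., 1])
est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
print(est_spec.shape)  # torch.Size([2, 257, 199])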
toolbox/torchaudio/models/frcrn/conv_stft.py
CHANGED
@@ -127,8 +127,8 @@ class ConviSTFT(nn.Module):


def main():
-    stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
-    istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")
+    stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex")
+    istft = ConviSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex")

    mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
toolbox/torchaudio/models/{simple_lstm_irm → lstm}/__init__.py
RENAMED
File without changes
toolbox/torchaudio/models/lstm/configuration_lstm.py
ADDED
@@ -0,0 +1,73 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+from toolbox.torchaudio.configuration_utils import PretrainedConfig
+
+
+class LstmConfig(PretrainedConfig):
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 segment_size: int = 32000,
+                 nfft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+                 win_type: str = "hann",
+
+                 hidden_size: int = 1024,
+                 num_layers: int = 2,
+                 dropout: float = 0.2,
+
+                 min_snr_db: float = -10,
+                 max_snr_db: float = 20,
+
+                 max_epochs: int = 100,
+                 batch_size: int = 4,
+                 num_workers: int = 4,
+                 seed: int = 1234,
+
+                 lr: float = 0.001,
+                 lr_scheduler: str = "CosineAnnealingLR",
+                 lr_scheduler_kwargs: dict = None,
+
+                 weight_decay: float = 0.00001,
+                 clip_grad_norm: float = 10.,
+                 eval_steps: int = 25000,
+
+                 **kwargs
+                 ):
+        super(LstmConfig, self).__init__(**kwargs)
+        self.sample_rate = sample_rate
+        self.segment_size = segment_size
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.win_type = win_type
+
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.dropout = dropout
+
+        self.min_snr_db = min_snr_db
+        self.max_snr_db = max_snr_db
+
+        self.max_epochs = max_epochs
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.seed = seed
+
+        self.lr = lr
+        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
+
+        self.weight_decay = weight_decay
+        self.clip_grad_norm = clip_grad_norm
+        self.eval_steps = eval_steps
+
+
+def main():
+    config = LstmConfig()
+    config.to_yaml_file("config.yaml")
+    return
+
+
+if __name__ == "__main__":
+    main()
toolbox/torchaudio/models/lstm/modeling_lstm.py
ADDED
@@ -0,0 +1,260 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+"""
+https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py
+"""
+import os
+from typing import Optional, Union, Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import torchaudio
+
+from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
+from toolbox.torchaudio.configuration_utils import CONFIG_FILE
+from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
+
+
+MODEL_FILE = "model.pt"
+
+
+class Transpose(nn.Module):
+    def __init__(self, dim0: int, dim1: int):
+        super(Transpose, self).__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, inputs: torch.Tensor):
+        inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1)
+        return inputs
+
+
+class LstmModel(nn.Module):
+    def __init__(self,
+                 nfft: int = 512,
+                 win_size: int = 512,
+                 hop_size: int = 256,
+                 win_type: str = "hann",
+                 hidden_size=1024,
+                 num_layers: int = 2,
+                 batch_first: bool = True,
+                 dropout: float = 0.2,
+                 ):
+        super(LstmModel, self).__init__()
+        self.nfft = nfft
+        self.win_size = win_size
+        self.hop_size = hop_size
+        self.win_type = win_type
+
+        self.spec_bins = nfft // 2 + 1
+        self.hidden_size = hidden_size
+
+        self.eps = 1e-8
+
+        self.stft = ConvSTFT(
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            win_type=self.win_type,
+            power=None,
+            requires_grad=False
+        )
+        self.istft = ConviSTFT(
+            nfft=self.nfft,
+            win_size=self.win_size,
+            hop_size=self.hop_size,
+            win_type=self.win_type,
+            requires_grad=False
+        )
+
+        self.lstm = nn.LSTM(input_size=self.spec_bins,
+                            hidden_size=hidden_size,
+                            num_layers=num_layers,
+                            batch_first=batch_first,
+                            dropout=dropout,
+                            )
+        self.linear = nn.Linear(in_features=hidden_size, out_features=self.spec_bins)
+        self.activation = nn.Sigmoid()
+
+    def signal_prepare(self, signal: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        if signal.dim() == 2:
+            signal = torch.unsqueeze(signal, dim=1)
+        _, _, n_samples = signal.shape
+        remainder = (n_samples - self.win_size) % self.hop_size
+        if remainder > 0:
+            n_samples_pad = self.hop_size - remainder
+            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
+        return signal, n_samples
+
+    def forward(self,
+                noisy: torch.Tensor,
+                h_state: Tuple[torch.Tensor, torch.Tensor] = None,
+                ):
+        noisy, num_samples = self.signal_prepare(noisy)
+        batch_size, _, num_samples_pad = noisy.shape
+        # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
+
+        mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
+        # shape: (b, f, t)
+        # t = (num_samples - win_size) / hop_size + 1
+
+        mask, h_state = self.forward_chunk(mag_noisy, h_state)
+        # mask shape: (b, f, t)
+
+        stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
+        denoise = self.istft.forward(stft_denoise)
+        # denoise shape: [b, 1, num_samples_pad]
+
+        denoise = denoise[:, :, :num_samples]
+        # denoise shape: [b, 1, num_samples]
+        return denoise, mask, h_state
+
+    def mag_pha_stft(self, noisy: torch.Tensor):
+        # noisy shape: [b, num_samples]
+        stft_noisy = self.stft.forward(noisy)
+        # stft_noisy shape: [b, f, t], torch.complex64
+
+        real = torch.real(stft_noisy)
+        imag = torch.imag(stft_noisy)
+        mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
+        pha_noisy = torch.atan2(imag, real)
+        # shape: (b, f, t)
+        return mag_noisy, pha_noisy
+
+    def forward_chunk(self,
+                      mag_noisy: torch.Tensor,
+                      h_state: Tuple[torch.Tensor, torch.Tensor] = None,
+                      ):
+        # mag_noisy shape: (b, f, t)
+        x = torch.transpose(mag_noisy, dim0=2, dim1=1)
+        # x shape: (b, t, f)
+        x, h_state = self.lstm.forward(x, hx=h_state)
+        x = self.linear.forward(x)
+        mask = self.activation(x)
+        # mask shape: (b, t, f)
+        mask = torch.transpose(mask, dim0=2, dim1=1)
+        # mask shape: (b, f, t)
+        return mask, h_state
+
+    def do_mask(self,
+                mag_noisy: torch.Tensor,
+                pha_noisy: torch.Tensor,
+                mask: torch.Tensor,
+                ):
+        # (b, f, t)
+        mag_denoise = mag_noisy * mask
+        stft_denoise = mag_denoise * torch.exp(1j * pha_noisy)
+        return stft_denoise
+
+
+class LstmPretrainedModel(LstmModel):
+    def __init__(self,
+                 config: LstmConfig,
+                 ):
+        super(LstmPretrainedModel, self).__init__(
+            nfft=config.nfft,
+            win_size=config.win_size,
+            hop_size=config.hop_size,
+            win_type=config.win_type,
+            hidden_size=config.hidden_size,
+            num_layers=config.num_layers,
+            dropout=config.dropout,
+        )
+        self.config = config
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        config = LstmConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        model = cls(config)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+        else:
+            ckpt_file = pretrained_model_name_or_path
+
+        with open(ckpt_file, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
+    def save_pretrained(self,
+                        save_directory: Union[str, os.PathLike],
+                        state_dict: Optional[dict] = None,
+                        ):
+
+        model = self
+
+        if state_dict is None:
+            state_dict = model.state_dict()
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # save state dict
+        model_file = os.path.join(save_directory, MODEL_FILE)
+        torch.save(state_dict, model_file)
+
+        # save config
+        config_file = os.path.join(save_directory, CONFIG_FILE)
+        self.config.to_yaml_file(config_file)
+        return save_directory
+
+
+def main():
+    config = LstmConfig()
+    model = LstmPretrainedModel(config)
+    model.eval()
+
+    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
+    noisy, _ = model.signal_prepare(noisy)
+    b, _, num_samples = noisy.shape
+    t = (num_samples - config.win_size) / config.hop_size + 1
+
+    waveform, mask, h_state = model.forward(noisy)
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+    print(waveform[:, :, 300: 302])
+
+    # noisy_pad shape: [b, 1, num_samples_pad]
+
+    h_state = None
+    sub_spec_list = list()
+    for i in range(int(t)):
+        begin = i * config.hop_size
+        end = begin + config.win_size
+        sub_noisy = noisy[:, :, begin:end]
+        mag_noisy, pha_noisy = model.mag_pha_stft(sub_noisy)
+        mask, h_state = model.forward_chunk(mag_noisy, h_state)
+        sub_spec = model.do_mask(mag_noisy, pha_noisy, mask)
+        sub_spec_list.append(sub_spec)
+
+    spec = torch.concat(sub_spec_list, dim=2)
+
+    # 1
+    waveform = model.istft.forward(spec)
+    waveform = waveform[:, :, :num_samples]
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+    print(waveform[:, :, 300: 302])
+
+    # 2
+    waveform_cache = None
+    coff_cache = None
+    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
+    for i in range(int(t)):
+        sub_spec = spec[:, :, i:i+1]
+        begin = i * config.hop_size
+        end = begin + config.win_size - config.hop_size
+        sub_waveform, waveform_cache, coff_cache = model.istft.forward_chunk(sub_spec, waveform_cache, coff_cache)
+        # end = begin + config.win_size
+        # sub_waveform = model.istft.forward(sub_spec)
+
+        # (b, 1, win_size)
+        waveform[:, :, begin:end] = sub_waveform
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+    print(waveform[:, :, 300: 302])
+
+    return
+
+
+if __name__ == "__main__":
+    main()
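Not part of the commit: a minimal streaming sketch (assuming it is run from the repo root, and the default win_size = 2 * hop_size) of the pattern a microphone input loop would follow with this model, as in variant 2 of main() above: the LSTM hidden state and the iSTFT caches are carried across overlapping win_size-sample chunks.

import torch

from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel

config = LstmConfig()
model = LstmPretrainedModel(config)
model.eval()

h_state, wav_cache, coff_cache = None, None, None
stream = torch.randn(size=(1, 1, config.win_size * 10), dtype=torch.float32)

with torch.no_grad():
    for begin in range(0, stream.shape[-1] - config.win_size + 1, config.hop_size):
        chunk = stream[:, :, begin: begin + config.win_size]
        mag, pha = model.mag_pha_stft(chunk)
        mask, h_state = model.forward_chunk(mag, h_state)
        sub_spec = model.do_mask(mag, pha, mask)
        sub_wav, wav_cache, coff_cache = model.istft.forward_chunk(sub_spec, wav_cache, coff_cache)
        # sub_wav shape: (1, 1, hop_size); ready to be written to the output device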
toolbox/torchaudio/models/lstm/yaml/config.yaml
ADDED
@@ -0,0 +1,32 @@
+model_name: "lstm"
+
+# spec
+sample_rate: 8000
+segment_size: 32000
+nfft: 320
+win_size: 320
+hop_size: 160
+win_type: hann
+
+# data
+max_snr_db: 20
+min_snr_db: -10
+
+# model
+hidden_size: 512
+num_layers: 3
+dropout: 0.1
+
+# train
+max_epochs: 100
+batch_size: 32
+num_workers: 4
+seed: 1234
+
+lr: 0.001
+lr_scheduler: CosineAnnealingLR
+lr_scheduler_kwargs: {}
+
+weight_decay: 0.00001
+clip_grad_norm: 10.0
+eval_steps: 25000
toolbox/torchaudio/models/simple_lstm_irm/configuration_simple_lstm_irm.py
DELETED
@@ -1,38 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-from toolbox.torchaudio.configuration_utils import PretrainedConfig
-
-
-class SimpleLstmIRMConfig(PretrainedConfig):
-    def __init__(self,
-                 sample_rate: int,
-                 n_fft: int,
-                 win_length: int,
-                 hop_length: int,
-
-                 num_bins: int,
-                 hidden_size: int,
-                 num_layers: int,
-                 batch_first: bool,
-                 dropout: float,
-                 lookback: int,
-                 lookahead: int,
-                 **kwargs
-                 ):
-        super(SimpleLstmIRMConfig, self).__init__(**kwargs)
-        self.sample_rate = sample_rate
-        self.n_fft = n_fft
-        self.win_length = win_length
-        self.hop_length = hop_length
-
-        self.num_bins = num_bins
-        self.hidden_size = hidden_size
-        self.num_layers = num_layers
-        self.batch_first = batch_first
-        self.dropout = dropout
-        self.lookback = lookback
-        self.lookahead = lookahead
-
-
-if __name__ == "__main__":
-    pass
toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py
DELETED
@@ -1,133 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-"""
-https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py
-"""
-import os
-from typing import Optional, Union
-
-import torch
-import torch.nn as nn
-import torchaudio
-
-from toolbox.torchaudio.models.simple_lstm_irm.configuration_simple_lstm_irm import SimpleLstmIRMConfig
-from toolbox.torchaudio.configuration_utils import CONFIG_FILE
-
-
-MODEL_FILE = "model.pt"
-
-
-class Transpose(nn.Module):
-    def __init__(self, dim0: int, dim1: int):
-        super(Transpose, self).__init__()
-        self.dim0 = dim0
-        self.dim1 = dim1
-
-    def forward(self, inputs: torch.Tensor):
-        inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1)
-        return inputs
-
-
-class SimpleLstmIRM(nn.Module):
-    """
-    Ideal ratio mask estimator:
-
-    """
-
-    def __init__(self, num_bins=257, hidden_size=1024,
-                 num_layers: int = 2,
-                 batch_first: bool = True,
-                 dropout: float = 0.4,
-                 ):
-        super(SimpleLstmIRM, self).__init__()
-        self.num_bins = num_bins
-        self.hidden_size = hidden_size
-
-        self.lstm = nn.LSTM(input_size=num_bins,
-                            hidden_size=hidden_size,
-                            num_layers=num_layers,
-                            batch_first=batch_first,
-                            dropout=dropout,
-                            )
-        self.linear = nn.Linear(in_features=hidden_size, out_features=num_bins)
-        self.activation = nn.Sigmoid()
-
-    def forward(self, spec: torch.Tensor):
-        # spec shape: (batch_size, num_bins, time_steps)
-        spec = torch.transpose(spec, dim0=2, dim1=1)
-        # frame_spec shape: (batch_size, time_steps, num_bins)
-        spec, _ = self.lstm(spec)
-        spec = self.linear(spec)
-        mask = self.activation(spec)
-        return mask
-
-
-class SimpleLstmIRMPretrainedModel(SimpleLstmIRM):
-    def __init__(self,
-                 config: SimpleLstmIRMConfig,
-                 ):
-        super(SimpleLstmIRMPretrainedModel, self).__init__(
-            num_bins=config.num_bins,
-            hidden_size=config.hidden_size,
-        )
-        self.config = config
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        config = SimpleLstmIRMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model = cls(config)
-
-        if os.path.isdir(pretrained_model_name_or_path):
-            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
-        else:
-            ckpt_file = pretrained_model_name_or_path
-
-        with open(ckpt_file, "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        model.load_state_dict(state_dict, strict=True)
-        return model
-
-    def save_pretrained(self,
-                        save_directory: Union[str, os.PathLike],
-                        state_dict: Optional[dict] = None,
-                        ):
-
-        model = self
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        os.makedirs(save_directory, exist_ok=True)
-
-        # save state dict
-        model_file = os.path.join(save_directory, MODEL_FILE)
-        torch.save(state_dict, model_file)
-
-        # save config
-        config_file = os.path.join(save_directory, CONFIG_FILE)
-        self.config.to_yaml_file(config_file)
-        return save_directory
-
-
-def main():
-    transformer = torchaudio.transforms.Spectrogram(
-        n_fft=512,
-        win_length=200,
-        hop_length=80,
-        window_fn=torch.hamming_window,
-    )
-
-    model = SimpleLstmIRM()
-
-    inputs = torch.randn(size=(1, 1600), dtype=torch.float32)
-    spec = transformer.forward(inputs)
-
-    output = model.forward(spec)
-    print(output.shape)
-    print(output)
-    return
-
-
-if __name__ == '__main__':
-    main()
toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml
DELETED
@@ -1,14 +0,0 @@
-model_name: "simple_lstm_irm"
-
-# spec
-sample_rate: 8000
-n_fft: 320
-win_length: 320
-hop_length: 80
-
-# model
-num_bins: 161
-hidden_size: 512
-num_layers: 3
-batch_first: true
-dropout: 0.1
toolbox/torchaudio/modules/conv_stft.py
CHANGED
@@ -59,11 +59,11 @@ class ConvSTFT(nn.Module):
        self.dim = self.nfft
        self.power = power

-    def forward(self,
-        if
-
+    def forward(self, waveform: torch.Tensor):
+        if waveform.dim() == 2:
+            waveform = torch.unsqueeze(waveform, 1)

-        matrix = F.conv1d(
+        matrix = F.conv1d(waveform, self.weight, stride=self.stride)
        dim = self.dim // 2 + 1
        real = matrix[:, :dim, :]
        imag = matrix[:, dim:, :]

@@ -99,6 +99,8 @@ class ConviSTFT(nn.Module):

        kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
        self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
+        # weight shape: [f*2, 1, nfft]
+        # f = nfft // 2 + 1

        self.win_size = win_size
        self.hop_size = hop_size

@@ -109,41 +111,158 @@ class ConviSTFT(nn.Module):

        self.register_buffer("window", window)
        self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
+        # window shape: [1, nfft, 1]
+        # enframe shape: [nfft, 1, nfft]

    def forward(self,
-
+                spec: torch.Tensor):
        """
-
+        self.weight shape: [f*2, 1, win_size]
+        self.window shape: [1, win_size, 1]
+        self.enframe shape: [win_size, 1, win_size]
+
+        :param spec: torch.Tensor, shape: [b, f, t, 2]
        :return:
        """
-
-
+        spec = torch.view_as_real(spec)
+        # spec shape: [b, f, t, 2]
+        matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
+        # matrix shape: [b, f*2, t]

        waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
+        # waveform shape: [b, 1, num_samples]

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, matrix.size(-1))**2
+        # t shape: [1, win_size, t]
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
+        # coff shape: [1, 1, num_samples]
        waveform = waveform / (coff + 1e-8)
+        # waveform = waveform / coff
        return waveform

+    def forward_chunk(self,
+                      spec: torch.Tensor,
+                      waveform_cache: torch.Tensor = None,
+                      coff_cache: torch.Tensor = None,
+                      ):
+        """
+        :param spec: shape: [b, f, t]
+        :param waveform_cache: shape: [b, 1, win_size - hop_size]
+        :param coff_cache: shape: [b, 1, win_size - hop_size]
+        :return:
+        """
+        spec = torch.view_as_real(spec)
+        matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
+
+        waveform_current = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
+
+        t = self.window.repeat(1, 1, matrix.size(-1))**2
+        coff_current = F.conv_transpose1d(t, self.enframe, stride=self.stride)
+
+        overlap_size = self.win_size - self.hop_size
+
+        if waveform_cache is not None:
+            waveform_overlap = waveform_current[:, :, :overlap_size] + waveform_cache
+            waveform_non_overlap = waveform_current[:, :, overlap_size:-self.hop_size]
+            waveform_output = torch.cat(tensors=[waveform_overlap, waveform_non_overlap], dim=-1)
+            new_waveform_cache = waveform_current[:, :, -self.hop_size:]
+        else:
+            waveform_output = waveform_current[:, :, :-self.hop_size]
+            new_waveform_cache = waveform_current[:, :, -self.hop_size:]
+
+        if coff_cache is not None:
+            coff_overlap = coff_current[:, :, :overlap_size] + coff_cache
+            coff_non_overlap = coff_current[:, :, overlap_size:-self.hop_size]
+            coff_output = torch.cat(tensors=[coff_overlap, coff_non_overlap], dim=-1)
+            new_coff_cache = coff_current[:, :, -self.hop_size:]
+        else:
+            coff_output = coff_current[:, :, :-self.hop_size]
+            new_coff_cache = coff_current[:, :, -self.hop_size:]
+
+        waveform_output = waveform_output / (coff_output + 1e-8)
+        return waveform_output, new_waveform_cache, new_coff_cache
+

def main():
-
-
+    nfft = 512
+    win_size = 512
+    hop_size = 256
+
+    stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
+    istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)

-
+    mixture = torch.rand(size=(1, 16000), dtype=torch.float32)
+    b, num_samples = mixture.shape
+    t = (num_samples - win_size) / hop_size + 1

    spec = stft.forward(mixture)
-
+    b, f, t = spec.shape
+
+    # If spec was produced by the stft above, the two waveform reconstruction methods below are equivalent; otherwise the reconstructed waveforms will differ.
+    # spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")

    waveform = istft.forward(spec)
    # shape: [batch_size, channels, num_samples]
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+    print(waveform[:, :, 300: 302])
+
+    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
+    for i in range(int(t)):
+        begin = i * hop_size
+        end = begin + win_size
+        sub_spec = spec[:, :, i:i+1]
+        sub_waveform = istft.forward(sub_spec)
+        # (b, 1, win_size)
+        waveform[:, :, begin:end] = sub_waveform
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+    print(waveform[:, :, 300: 302])

    return

+
+def main2():
+    nfft = 512
+    win_size = 512
+    hop_size = 256
+
+    stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
+    istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)
+
+    mixture = torch.rand(size=(1, 16128), dtype=torch.float32)
+    b, num_samples = mixture.shape
+
+    spec = stft.forward(mixture)
+    b, f, t = spec.shape
+
+    # If spec was produced by the stft above, the two waveform reconstruction methods below are equivalent; otherwise the reconstructed waveforms will differ.
+    spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
+    print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
+
+    waveform = istft.forward(spec)
+    # shape: [batch_size, channels, num_samples]
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+    print(waveform[:, :, 300: 302])
+
+    waveform_cache = None
+    coff_cache = None
+    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
+    for i in range(int(t)):
+        sub_spec = spec[:, :, i:i+1]
+        begin = i * hop_size
+
+        end = begin + win_size - hop_size
+        sub_waveform, waveform_cache, coff_cache = istft.forward_chunk(sub_spec, waveform_cache, coff_cache)
+        # end = begin + win_size
+        # sub_waveform = istft.forward(sub_spec)
+
+        waveform[:, :, begin:end] = sub_waveform
+    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
+    print(waveform[:, :, 300: 302])

    return


if __name__ == "__main__":
-    main()
+    main2()
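Not part of the commit: a quick consistency check (assuming it is run from the repo root) for the new forward_chunk above. Because the per-chunk caches carry both the overlapping waveform tail and the window-compensation coefficients, chunked synthesis should match the offline iSTFT on an STFT-consistent spectrum.

import torch

from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT

stft = ConvSTFT(nfft=512, win_size=512, hop_size=256, power=None)
istft = ConviSTFT(nfft=512, win_size=512, hop_size=256)

x = torch.rand(size=(1, 16000), dtype=torch.float32)
spec = stft.forward(x)

offline = istft.forward(spec)

chunks, wav_cache, coff_cache = [], None, None
for i in range(spec.shape[-1]):
    sub, wav_cache, coff_cache = istft.forward_chunk(spec[:, :, i:i+1], wav_cache, coff_cache)
    chunks.append(sub)
streamed = torch.cat(chunks, dim=-1)

# expected to print True when spec really is the STFT of a waveform
print(torch.allclose(offline[..., :streamed.shape[-1]], streamed, atol=1e-4))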
toolbox/torchaudio/modules/freq_bands/__init__.py
ADDED
@@ -0,0 +1,6 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+
+if __name__ == "__main__":
+    pass
toolbox/torchaudio/modules/freq_bands/erb_bands.py
ADDED
@@ -0,0 +1,173 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+class ErbBandsNumpy(object):
+
+    @staticmethod
+    def freq2erb(freq_hz: float) -> float:
+        """
+        https://www.cnblogs.com/LXP-Never/p/16011229.html
+        1 / (24.7 * 9.265) = 0.00436976
+        """
+        return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1)
+
+    @staticmethod
+    def erb2freq(n_erb: float) -> float:
+        return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1)
+
+    @classmethod
+    def get_erb_widths(cls, sample_rate: int, nfft: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
+        """
+        https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs
+        :param sample_rate:
+        :param nfft:
+        :param erb_bins: number of ERB (Equivalent Rectangular Bandwidth) channels.
+        :param min_freq_bins_for_erb: minimum number of frequency bins per ERB band.
+        :return:
+        """
+        nyq_freq = sample_rate / 2.
+        freq_width: float = sample_rate / nfft
+
+        min_erb: float = cls.freq2erb(0.)
+        max_erb: float = cls.freq2erb(nyq_freq)
+
+        erb = [0] * erb_bins
+        step = (max_erb - min_erb) / erb_bins
+
+        prev_freq_bin = 0
+        freq_over = 0
+        for i in range(1, erb_bins + 1):
+            f = cls.erb2freq(min_erb + i * step)
+            freq_bin = int(round(f / freq_width))
+            freq_bins = freq_bin - prev_freq_bin - freq_over
+
+            if freq_bins < min_freq_bins_for_erb:
+                freq_over = min_freq_bins_for_erb - freq_bins
+                freq_bins = min_freq_bins_for_erb
+            else:
+                freq_over = 0
+            erb[i - 1] = freq_bins
+            prev_freq_bin = freq_bin
+
+        erb[erb_bins - 1] += 1
+        too_large = sum(erb) - (nfft / 2 + 1)
+        if too_large > 0:
+            erb[erb_bins - 1] -= too_large
+        return np.array(erb, dtype=np.uint64)
+
+    @staticmethod
+    def get_erb_filter_bank(erb_widths: np.ndarray,
+                            normalized: bool = True,
+                            inverse: bool = False,
+                            ):
+        num_freq_bins = int(np.sum(erb_widths))
+        num_erb_bins = len(erb_widths)
+
+        fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins))
+
+        points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1]
+        for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())):
+            fb[b: b + w, i] = 1
+
+        if inverse:
+            fb = fb.T
+            if not normalized:
+                fb /= np.sum(fb, axis=1, keepdims=True)
+        else:
+            if normalized:
+                fb /= np.sum(fb, axis=0)
+        return fb
+
+    @staticmethod
+    def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
+        """
+        ERB filterbank and transform to decibel scale.
+
+        :param spec: Spectrum of shape [B, C, T, F].
+        :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths,
+                       where B are the number of ERB bins.
+        :param db: Whether to transform the output into decibel scale. Defaults to `True`.
+        :return:
+        """
+        # complex spec to power spec. (real * real + image * image)
+        spec_ = np.abs(spec) ** 2
+
+        # spec to erb feature.
+        erb_feat = np.matmul(spec_, erb_fb)
+
+        if db:
+            erb_feat = 10 * np.log10(erb_feat + 1e-10)
+
+        erb_feat = np.array(erb_feat, dtype=np.float32)
+        return erb_feat
+
+
+class ErbBands(nn.Module):
+    def __init__(self,
+                 sample_rate: int = 8000,
+                 nfft: int = 512,
+                 erb_bins: int = 32,
+                 min_freq_bins_for_erb: int = 2,
+                 ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.nfft = nfft
+        self.erb_bins = erb_bins
+        self.min_freq_bins_for_erb = min_freq_bins_for_erb
+
+        erb_fb, erb_fb_inv = self.init_erb_fb()
+        self.erb_fb = torch.tensor(erb_fb, dtype=torch.float32, requires_grad=False)
+        self.erb_fb_inv = torch.tensor(erb_fb_inv, dtype=torch.float32, requires_grad=False)
+
+    def init_erb_fb(self):
+        erb_widths = ErbBandsNumpy.get_erb_widths(
+            sample_rate=self.sample_rate,
+            nfft=self.nfft,
+            erb_bins=self.erb_bins,
+            min_freq_bins_for_erb=self.min_freq_bins_for_erb,
+        )
+        erb_fb = ErbBandsNumpy.get_erb_filter_bank(
+            erb_widths=erb_widths,
+            normalized=True,
+            inverse=False,
+        )
+        erb_fb_inv = ErbBandsNumpy.get_erb_filter_bank(
+            erb_widths=erb_widths,
+            normalized=True,
+            inverse=True,
+        )
+        return erb_fb, erb_fb_inv
+
+    def erb_scale(self, spec: torch.Tensor, db: bool = True):
+        spec_erb = torch.matmul(spec, self.erb_fb)
+        if db:
+            spec_erb = 10 * torch.log10(spec_erb + 1e-10)
+        return spec_erb
+
+    def erb_scale_inv(self, spec_erb: torch.Tensor):
+        spec = torch.matmul(spec_erb, self.erb_fb_inv)
+        return spec
+
+
+def main():
+
+    erb_bands = ErbBands()
+
+    spec = torch.randn(size=(2, 199, 257), dtype=torch.float32)
+    spec_erb = erb_bands.erb_scale(spec)
+    print(spec_erb.shape)
+
+    spec = erb_bands.erb_scale_inv(spec_erb)
+    print(spec.shape)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
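Not part of the commit: a minimal sketch (assuming it is run from the repo root) of how DfNet uses this module. A power spectrum with f = nfft // 2 + 1 = 257 bins is compressed to erb_bins = 32 log-scale features for the encoder, and the mask predicted in ERB space is expanded back to full frequency resolution.

import torch

from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands

erb_bands = ErbBands(sample_rate=8000, nfft=512, erb_bins=32, min_freq_bins_for_erb=2)

spec_pow = torch.rand(size=(2, 199, 257), dtype=torch.float32)  # [b, t, f]
feat_erb = erb_bands.erb_scale(spec_pow, db=True)
print(feat_erb.shape)  # torch.Size([2, 199, 32])

mask_erb = torch.sigmoid(torch.randn(size=(2, 199, 32), dtype=torch.float32))
mask = erb_bands.erb_scale_inv(mask_erb)
print(mask.shape)  # torch.Size([2, 199, 257])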