HoneyTian committed
Commit 85a1b16 · 1 Parent(s): 8c3c188

add microphone audio input
Dockerfile CHANGED
@@ -4,6 +4,9 @@ WORKDIR /code
 
  COPY . /code
 
+ RUN apt-get update
+ RUN apt-get install -y ffmpeg build-essential
+
  RUN pip install --upgrade pip
  RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
 
examples/dfnet/step_2_train_model.py CHANGED
@@ -15,6 +15,8 @@ import sys
  import shutil
  from typing import List
 
+ from fontTools.varLib.plot import stops
+
  pwd = os.path.abspath(os.path.dirname(__file__))
  sys.path.append(os.path.join(pwd, "../../"))
 
@@ -243,7 +245,11 @@ def main():
  step_idx = 0 if last_step_idx == -1 else last_step_idx
 
  logger.info("training")
+ early_stop_flag = False
  for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+ if early_stop_flag:
+ break
+
  # train
  model.train()
 
examples/dfnet/yaml/config.yaml CHANGED
@@ -68,7 +68,7 @@ seed: 1234
 
  num_workers: 8
  batch_size: 32
- eval_steps: 10000
+ eval_steps: 25000
 
  # runtime
  use_post_filter: true
examples/dtln/step_2_train_model.py CHANGED
@@ -235,7 +235,11 @@ def main():
  step_idx = 0 if last_step_idx == -1 else last_step_idx
 
  logger.info("training")
+ early_stop_flag = False
  for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+ if early_stop_flag:
+ break
+
  # train
  model.train()
 
examples/dtln/yaml/config.yaml CHANGED
@@ -1,23 +1,29 @@
  model_name: "DTLN"
 
+ # spec
  sample_rate: 8000
  fft_size: 256
  hop_size: 128
  win_type: hann
 
+ # data
  max_snr_db: 20
  min_snr_db: -10
 
+ # model
  encoder_size: 256
 
- max_epochs: 100
- batch_size: 4
- num_workers: 4
- seed: 1234
- eval_steps: 25000
-
+ # train
  lr: 0.001
- lr_scheduler: CosineAnnealingLR
- lr_scheduler_kwargs: {}
+ lr_scheduler: "CosineAnnealingLR"
+ lr_scheduler_kwargs:
+   T_max: 250000
+   eta_min: 0.0001
 
+ max_epochs: 100
  clip_grad_norm: 10.0
+ seed: 1234
+
+ batch_size: 32
+ num_workers: 4
+ eval_steps: 25000
examples/frcrn/step_2_train_model.py CHANGED
@@ -238,7 +238,11 @@ def main():
  step_idx = 0 if last_step_idx == -1 else last_step_idx
 
  logger.info("training")
+ early_stop_flag = False
  for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
+ if early_stop_flag:
+ break
+
  # train
  model.train()
 
examples/{simple_lstm_irm → lstm}/run.sh RENAMED
File without changes
examples/{simple_lstm_irm → lstm}/step_1_prepare_data.py RENAMED
File without changes
examples/lstm/step_2_train_model.py ADDED
@@ -0,0 +1,476 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.utils.data.dataloader import DataLoader
25
+ import torchaudio
26
+ from tqdm import tqdm
27
+
28
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
29
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
30
+ from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
31
+ from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel
32
+
33
+
34
+ def get_args():
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
38
+ parser.add_argument("--max_epochs", default=100, type=int)
39
+
40
+ parser.add_argument("--batch_size", default=64, type=int)
41
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
42
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
43
+ parser.add_argument("--patience", default=10, type=int)
44
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
45
+ parser.add_argument("--seed", default=0, type=int)
46
+
47
+ parser.add_argument("--config_file", default="config.yaml", type=str)
48
+
49
+ args = parser.parse_args()
50
+ return args
51
+
52
+
53
+ def logging_config(file_dir: str):
54
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
+
56
+ logging.basicConfig(format=fmt,
57
+ datefmt="%m/%d/%Y %H:%M:%S",
58
+ level=logging.INFO)
59
+ file_handler = TimedRotatingFileHandler(
60
+ filename=os.path.join(file_dir, "main.log"),
61
+ encoding="utf-8",
62
+ when="D",
63
+ interval=1,
64
+ backupCount=7
65
+ )
66
+ file_handler.setLevel(logging.INFO)
67
+ file_handler.setFormatter(logging.Formatter(fmt))
68
+ logger = logging.getLogger(__name__)
69
+ logger.addHandler(file_handler)
70
+
71
+ return logger
72
+
73
+
74
+ class CollateFunction(object):
75
+ def __init__(self,
76
+ n_fft: int = 512,
77
+ win_length: int = 200,
78
+ hop_length: int = 80,
79
+ window_fn: str = "hamming",
80
+ irm_beta: float = 1.0,
81
+ epsilon: float = 1e-8,
82
+ ):
83
+ self.n_fft = n_fft
84
+ self.win_length = win_length
85
+ self.hop_length = hop_length
86
+ self.window_fn = window_fn
87
+ self.irm_beta = irm_beta
88
+ self.epsilon = epsilon
89
+
90
+ self.stft_mag = torchaudio.transforms.Spectrogram(
91
+ n_fft=self.n_fft,
92
+ win_length=self.win_length,
93
+ hop_length=self.hop_length,
94
+ power=1.0,
95
+ window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
96
+ )
97
+ self.stft_complex = torchaudio.transforms.Spectrogram(
98
+ n_fft=self.n_fft,
99
+ win_length=self.win_length,
100
+ hop_length=self.hop_length,
101
+ power=None,
102
+ window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
103
+ )
104
+
105
+ self.istft = torchaudio.transforms.InverseSpectrogram(
106
+ n_fft=self.n_fft,
107
+ win_length=self.win_length,
108
+ hop_length=self.hop_length,
109
+ window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
110
+ )
111
+
112
+ def __call__(self, batch: List[dict]):
113
+ mag_noisy_audios = list()
114
+ pha_noisy_audios = list()
115
+ irm_gth = list()
116
+
117
+ clean_audios = list()
118
+
119
+ for sample in batch:
120
+ noise_audio: torch.Tensor = sample["noise_wave"]
121
+ clean_audio: torch.Tensor = sample["speech_wave"]
122
+ noisy_audio: torch.Tensor = sample["mix_wave"]
123
+ snr_db: float = sample["snr_db"]
124
+
125
+ mag_noise = self.stft_mag.forward(noise_audio)
126
+ mag_clean = self.stft_mag.forward(clean_audio)
127
+ stft_noisy = self.stft_complex.forward(noisy_audio)
128
+
129
+ irm_clean = mag_clean / (mag_noise + mag_clean + self.epsilon)
130
+ irm_clean = torch.pow(irm_clean, self.irm_beta)
131
+
132
+ real = torch.real(stft_noisy)
133
+ imag = torch.imag(stft_noisy)
134
+ mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
135
+ pha_noisy = torch.atan2(imag, real)
136
+
137
+ mag_noisy_audios.append(mag_noisy)
138
+ pha_noisy_audios.append(pha_noisy)
139
+ irm_gth.append(irm_clean)
140
+ clean_audios.append(clean_audio)
141
+
142
+ mag_noisy_audios = torch.stack(mag_noisy_audios)
143
+ pha_noisy_audios = torch.stack(pha_noisy_audios)
144
+ irm_gth = torch.stack(irm_gth)
145
+ clean_audios = torch.stack(clean_audios)
146
+
147
+ # assert
148
+ if torch.any(torch.isnan(mag_noisy_audios)):
149
+ raise AssertionError("nan in mag_noisy_audios Tensor")
150
+ if torch.any(torch.isnan(pha_noisy_audios)):
151
+ raise AssertionError("nan in pha_noisy_audios Tensor")
152
+ if torch.any(torch.isnan(irm_gth)):
153
+ raise AssertionError("nan in irm_gth Tensor")
154
+ if torch.any(torch.isnan(clean_audios)):
155
+ raise AssertionError("nan in clean_audios Tensor")
156
+
157
+ return mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios
158
+
159
+ def enhance(self, mag_noisy: torch.Tensor, pha_noisy: torch.Tensor, irm_speech: torch.Tensor):
160
+ mag_denoise = mag_noisy * irm_speech
161
+ stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
162
+ denoise = self.istft.forward(stft_denoise)
163
+ return denoise
164
+
165
+
166
+ collate_fn = CollateFunction()
167
+
168
+
169
+ def main():
170
+ args = get_args()
171
+
172
+ config = LstmConfig.from_pretrained(
173
+ pretrained_model_name_or_path=args.config_file,
174
+ )
175
+
176
+ serialization_dir = Path(args.serialization_dir)
177
+ serialization_dir.mkdir(parents=True, exist_ok=True)
178
+
179
+ logger = logging_config(serialization_dir)
180
+
181
+ random.seed(args.seed)
182
+ np.random.seed(args.seed)
183
+ torch.manual_seed(args.seed)
184
+ logger.info("set seed: {}".format(args.seed))
185
+
186
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
187
+ n_gpu = torch.cuda.device_count()
188
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
189
+
190
+ # datasets
191
+ logger.info("prepare datasets")
192
+ train_dataset = DenoiseJsonlDataset(
193
+ jsonl_file=args.train_dataset,
194
+ expected_sample_rate=config.sample_rate,
195
+ max_wave_value=32768.0,
196
+ min_snr_db=config.min_snr_db,
197
+ max_snr_db=config.max_snr_db,
198
+ # skip=225000,
199
+ )
200
+ valid_dataset = DenoiseJsonlDataset(
201
+ jsonl_file=args.valid_dataset,
202
+ expected_sample_rate=config.sample_rate,
203
+ max_wave_value=32768.0,
204
+ min_snr_db=config.min_snr_db,
205
+ max_snr_db=config.max_snr_db,
206
+ )
207
+ train_data_loader = DataLoader(
208
+ dataset=train_dataset,
209
+ batch_size=config.batch_size,
210
+ # shuffle=True,
211
+ sampler=None,
212
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
213
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
214
+ collate_fn=collate_fn,
215
+ pin_memory=False,
216
+ prefetch_factor=None if platform.system() == "Windows" else 2,
217
+ )
218
+ valid_data_loader = DataLoader(
219
+ dataset=valid_dataset,
220
+ batch_size=config.batch_size,
221
+ # shuffle=True,
222
+ sampler=None,
223
+ # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
224
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
225
+ collate_fn=collate_fn,
226
+ pin_memory=False,
227
+ prefetch_factor=None if platform.system() == "Windows" else 2,
228
+ )
229
+
230
+ # models
231
+ logger.info(f"prepare models. config_file: {args.config_file}")
232
+ model = LstmPretrainedModel(
233
+ config=config,
234
+ )
235
+ model.to(device)
236
+ model.train()
237
+
238
+ # optimizer
239
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
240
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
241
+
242
+ # resume training
243
+ last_step_idx = -1
244
+ last_epoch = -1
245
+ for step_idx_str in serialization_dir.glob("steps-*"):
246
+ step_idx_str = Path(step_idx_str)
247
+ step_idx = step_idx_str.stem.split("-")[1]
248
+ step_idx = int(step_idx)
249
+ if step_idx > last_step_idx:
250
+ last_step_idx = step_idx
251
+ # last_epoch = 1
252
+
253
+ if last_step_idx != -1:
254
+ logger.info(f"resume from steps-{last_step_idx}.")
255
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
256
+ optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
257
+
258
+ logger.info(f"load state dict for model.")
259
+ with open(model_pt.as_posix(), "rb") as f:
260
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
261
+ model.load_state_dict(state_dict, strict=True)
262
+
263
+ logger.info(f"load state dict for optimizer.")
264
+ with open(optimizer_pth.as_posix(), "rb") as f:
265
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
266
+ optimizer.load_state_dict(state_dict)
267
+
268
+ if config.lr_scheduler == "CosineAnnealingLR":
269
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
270
+ optimizer,
271
+ last_epoch=last_epoch,
272
+ # T_max=10 * config.eval_steps,
273
+ # eta_min=0.01 * config.lr,
274
+ **config.lr_scheduler_kwargs,
275
+ )
276
+ elif config.lr_scheduler == "MultiStepLR":
277
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
278
+ optimizer,
279
+ last_epoch=last_epoch,
280
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
281
+ )
282
+ else:
283
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
284
+
285
+ mse_loss_fn = nn.MSELoss(
286
+ reduction="mean",
287
+ ).to(device)
288
+
289
+ # training loop
290
+ logger.info("training")
291
+
292
+ average_pesq_score = 1000000000
293
+ average_loss = 1000000000
294
+
295
+ model_list = list()
296
+ best_epoch_idx = None
297
+ best_step_idx = None
298
+ best_metric = None
299
+ patience_count = 0
300
+
301
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
302
+
303
+ logger.info("training")
304
+ early_stop_flag = False
305
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
306
+ if early_stop_flag:
307
+ break
308
+
309
+ # train
310
+ model.train()
311
+
312
+ total_pesq_score = 0.
313
+ total_loss = 0.
314
+ total_batches = 0.
315
+
316
+ progress_bar_train = tqdm(
317
+ initial=step_idx,
318
+ desc="Training; epoch: {}".format(epoch_idx),
319
+ )
320
+ for train_batch in train_data_loader:
321
+ mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = train_batch
322
+ mag_noisy_audios = mag_noisy_audios.to(device)
323
+ pha_noisy_audios = pha_noisy_audios.to(device)
324
+ irm_gth = irm_gth.to(device)
325
+ clean_audios = clean_audios.to(device)
326
+
327
+ irm = model.forward(mag_noisy_audios)
328
+ denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm)
329
+ loss = mse_loss_fn.forward(irm, irm_gth)
330
+
331
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
332
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
333
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
334
+
335
+ optimizer.zero_grad()
336
+ loss.backward()
337
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
338
+ optimizer.step()
339
+ lr_scheduler.step()
340
+
341
+ total_pesq_score += pesq_score
342
+ total_loss += loss.item()
343
+ total_batches += 1
344
+
345
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
346
+ average_loss = round(total_loss / total_batches, 4)
347
+
348
+ progress_bar_train.update(1)
349
+ progress_bar_train.set_postfix({
350
+ "lr": lr_scheduler.get_last_lr()[0],
351
+ "pesq_score": average_pesq_score,
352
+ "loss": average_loss,
353
+ })
354
+
355
+ # evaluation
356
+ step_idx += 1
357
+ if step_idx % config.eval_steps == 0:
358
+ with torch.no_grad():
359
+ torch.cuda.empty_cache()
360
+
361
+ total_pesq_score = 0.
362
+ total_loss = 0.
363
+ total_batches = 0.
364
+
365
+ progress_bar_train.close()
366
+ progress_bar_eval = tqdm(
367
+ desc="Evaluation; steps-{}k".format(int(step_idx / 1000)),
368
+ )
369
+
370
+ for eval_batch in valid_data_loader:
371
+ mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = eval_batch
372
+ mag_noisy_audios = mag_noisy_audios.to(device)
373
+ pha_noisy_audios = pha_noisy_audios.to(device)
374
+ irm_gth = irm_gth.to(device)
375
+ clean_audios = clean_audios.to(device)
376
+
377
+ with torch.no_grad():
378
+ irm = model.forward(mag_noisy_audios)
379
+ denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm)
380
+ loss = mse_loss_fn.forward(irm, irm_gth)
381
+
382
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
383
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
384
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
385
+
386
+ optimizer.zero_grad()
387
+ loss.backward()
388
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
389
+ optimizer.step()
390
+ lr_scheduler.step()
391
+
392
+ total_pesq_score += pesq_score
393
+ total_loss += loss.item()
394
+ total_batches += 1
395
+
396
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
397
+ average_loss = round(total_loss / total_batches, 4)
398
+
399
+ progress_bar_eval.update(1)
400
+ progress_bar_eval.set_postfix({
401
+ "lr": lr_scheduler.get_last_lr()[0],
402
+ "pesq_score": average_pesq_score,
403
+ "loss": average_loss,
404
+ })
405
+
406
+ total_pesq_score = 0.
407
+ total_loss = 0.
408
+ total_batches = 0.
409
+
410
+ progress_bar_eval.close()
411
+ progress_bar_train = tqdm(
412
+ initial=progress_bar_train.n,
413
+ postfix=progress_bar_train.postfix,
414
+ desc=progress_bar_train.desc,
415
+ )
416
+
417
+ # save path
418
+ epoch_dir = serialization_dir / "epoch-{}".format(epoch_idx)
419
+ epoch_dir.mkdir(parents=True, exist_ok=False)
420
+
421
+ # save models
422
+ model.save_pretrained(epoch_dir.as_posix())
423
+
424
+ model_list.append(epoch_dir)
425
+ if len(model_list) >= args.num_serialized_models_to_keep:
426
+ model_to_delete: Path = model_list.pop(0)
427
+ shutil.rmtree(model_to_delete.as_posix())
428
+
429
+ # save metric
430
+ if best_metric is None:
431
+ best_epoch_idx = epoch_idx
432
+ best_step_idx = step_idx
433
+ best_metric = average_pesq_score
434
+ elif average_pesq_score >= best_metric:
435
+ # great is better.
436
+ best_epoch_idx = epoch_idx
437
+ best_step_idx = step_idx
438
+ best_metric = average_pesq_score
439
+ else:
440
+ pass
441
+
442
+ metrics = {
443
+ "epoch_idx": epoch_idx,
444
+ "best_epoch_idx": best_epoch_idx,
445
+ "best_step_idx": best_step_idx,
446
+ "pesq_score": average_pesq_score,
447
+ "loss": average_loss,
448
+ }
449
+ metrics_filename = epoch_dir / "metrics_epoch.json"
450
+ with open(metrics_filename, "w", encoding="utf-8") as f:
451
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
452
+
453
+ # save best
454
+ best_dir = serialization_dir / "best"
455
+ if best_epoch_idx == epoch_idx:
456
+ if best_dir.exists():
457
+ shutil.rmtree(best_dir)
458
+ shutil.copytree(epoch_dir, best_dir)
459
+
460
+ # early stop
461
+ early_stop_flag = False
462
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
463
+ patience_count = 0
464
+ else:
465
+ patience_count += 1
466
+ if patience_count >= args.patience:
467
+ early_stop_flag = True
468
+
469
+ # early stop
470
+ if early_stop_flag:
471
+ break
472
+ return
473
+
474
+
475
+ if __name__ == '__main__':
476
+ main()
examples/{simple_lstm_irm → lstm}/step_3_evaluation.py RENAMED
@@ -19,7 +19,7 @@ import torch.nn as nn
  import torchaudio
  from tqdm import tqdm
 
- from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel
+ from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel
 
 
  def get_args():
@@ -147,7 +147,7 @@ def main():
  logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
 
  logger.info("prepare model")
- model = SimpleLstmIRMPretrainedModel.from_pretrained(
+ model = LstmPretrainedModel.from_pretrained(
  pretrained_model_name_or_path=args.model_dir,
  )
  model.to(device)
examples/mpnet/step_2_train_model.py CHANGED
@@ -225,7 +225,11 @@ def main():
  patience_count = 0
 
  logger.info("training")
+ early_stop_flag = False
  for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
+ if early_stop_flag:
+ break
+
  # train
  generator.train()
  discriminator.train()
examples/simple_lstm_irm/step_2_train_model.py DELETED
@@ -1,346 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
5
- """
6
- import argparse
7
- import json
8
- import logging
9
- from logging.handlers import TimedRotatingFileHandler
10
- import os
11
- import platform
12
- from pathlib import Path
13
- import random
14
- import sys
15
- import shutil
16
- from typing import List
17
-
18
- pwd = os.path.abspath(os.path.dirname(__file__))
19
- sys.path.append(os.path.join(pwd, "../../"))
20
-
21
- import numpy as np
22
- import torch
23
- import torch.nn as nn
24
- from torch.utils.data.dataloader import DataLoader
25
- import torchaudio
26
- from tqdm import tqdm
27
-
28
- from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
29
- from toolbox.torchaudio.models.simple_lstm_irm.configuration_simple_lstm_irm import SimpleLstmIRMConfig
30
- from toolbox.torchaudio.models.simple_lstm_irm.modeling_simple_lstm_irm import SimpleLstmIRMPretrainedModel
31
-
32
-
33
- def get_args():
34
- parser = argparse.ArgumentParser()
35
- parser.add_argument("--train_dataset", default="train.xlsx", type=str)
36
- parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
37
-
38
- parser.add_argument("--max_epochs", default=100, type=int)
39
-
40
- parser.add_argument("--batch_size", default=64, type=int)
41
- parser.add_argument("--learning_rate", default=1e-3, type=float)
42
- parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
43
- parser.add_argument("--patience", default=5, type=int)
44
- parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
45
- parser.add_argument("--seed", default=0, type=int)
46
-
47
- parser.add_argument("--config_file", default="config.yaml", type=str)
48
-
49
- args = parser.parse_args()
50
- return args
51
-
52
-
53
- def logging_config(file_dir: str):
54
- fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
55
-
56
- logging.basicConfig(format=fmt,
57
- datefmt="%m/%d/%Y %H:%M:%S",
58
- level=logging.INFO)
59
- file_handler = TimedRotatingFileHandler(
60
- filename=os.path.join(file_dir, "main.log"),
61
- encoding="utf-8",
62
- when="D",
63
- interval=1,
64
- backupCount=7
65
- )
66
- file_handler.setLevel(logging.INFO)
67
- file_handler.setFormatter(logging.Formatter(fmt))
68
- logger = logging.getLogger(__name__)
69
- logger.addHandler(file_handler)
70
-
71
- return logger
72
-
73
-
74
- class CollateFunction(object):
75
- def __init__(self,
76
- n_fft: int = 512,
77
- win_length: int = 200,
78
- hop_length: int = 80,
79
- window_fn: str = "hamming",
80
- irm_beta: float = 1.0,
81
- epsilon: float = 1e-8,
82
- ):
83
- self.n_fft = n_fft
84
- self.win_length = win_length
85
- self.hop_length = hop_length
86
- self.window_fn = window_fn
87
- self.irm_beta = irm_beta
88
- self.epsilon = epsilon
89
-
90
- self.transform = torchaudio.transforms.Spectrogram(
91
- n_fft=self.n_fft,
92
- win_length=self.win_length,
93
- hop_length=self.hop_length,
94
- power=2.0,
95
- window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
96
- )
97
-
98
- def __call__(self, batch: List[dict]):
99
- mix_spec_list = list()
100
- speech_irm_list = list()
101
- snr_db_list = list()
102
- for sample in batch:
103
- noise_wave: torch.Tensor = sample["noise_wave"]
104
- speech_wave: torch.Tensor = sample["speech_wave"]
105
- mix_wave: torch.Tensor = sample["mix_wave"]
106
- snr_db: float = sample["snr_db"]
107
-
108
- noise_spec = self.transform.forward(noise_wave)
109
- speech_spec = self.transform.forward(speech_wave)
110
- mix_spec = self.transform.forward(mix_wave)
111
-
112
- # noise_irm = noise_spec / (noise_spec + speech_spec)
113
- speech_irm = speech_spec / (noise_spec + speech_spec + self.epsilon)
114
- speech_irm = torch.pow(speech_irm, self.irm_beta)
115
-
116
- mix_spec_list.append(mix_spec)
117
- speech_irm_list.append(speech_irm)
118
- snr_db_list.append(torch.tensor(snr_db, dtype=torch.float32))
119
-
120
- mix_spec_list = torch.stack(mix_spec_list)
121
- speech_irm_list = torch.stack(speech_irm_list)
122
- snr_db_list = torch.stack(snr_db_list) # shape: (batch_size,)
123
-
124
- # assert
125
- if torch.any(torch.isnan(mix_spec_list)):
126
- raise AssertionError("nan in mix_spec Tensor")
127
- if torch.any(torch.isnan(speech_irm_list)):
128
- raise AssertionError("nan in speech_irm Tensor")
129
- if torch.any(torch.isnan(snr_db_list)):
130
- raise AssertionError("nan in snr_db Tensor")
131
-
132
- return mix_spec_list, speech_irm_list, snr_db_list
133
-
134
-
135
- collate_fn = CollateFunction()
136
-
137
-
138
- def main():
139
- args = get_args()
140
-
141
- serialization_dir = Path(args.serialization_dir)
142
- serialization_dir.mkdir(parents=True, exist_ok=True)
143
-
144
- logger = logging_config(serialization_dir)
145
-
146
- random.seed(args.seed)
147
- np.random.seed(args.seed)
148
- torch.manual_seed(args.seed)
149
- logger.info("set seed: {}".format(args.seed))
150
-
151
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152
- n_gpu = torch.cuda.device_count()
153
- logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
154
-
155
- # datasets
156
- logger.info("prepare datasets")
157
- train_dataset = DenoiseExcelDataset(
158
- excel_file=args.train_dataset,
159
- expected_sample_rate=8000,
160
- max_wave_value=32768.0,
161
- )
162
- valid_dataset = DenoiseExcelDataset(
163
- excel_file=args.valid_dataset,
164
- expected_sample_rate=8000,
165
- max_wave_value=32768.0,
166
- )
167
- train_data_loader = DataLoader(
168
- dataset=train_dataset,
169
- batch_size=args.batch_size,
170
- shuffle=True,
171
- # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
172
- num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
173
- collate_fn=collate_fn,
174
- pin_memory=False,
175
- # prefetch_factor=64,
176
- )
177
- valid_data_loader = DataLoader(
178
- dataset=valid_dataset,
179
- batch_size=args.batch_size,
180
- shuffle=True,
181
- # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
182
- num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
183
- collate_fn=collate_fn,
184
- pin_memory=False,
185
- # prefetch_factor=64,
186
- )
187
-
188
- # models
189
- logger.info(f"prepare models. config_file: {args.config_file}")
190
- config = SimpleLstmIRMConfig.from_pretrained(
191
- pretrained_model_name_or_path=args.config_file,
192
- # num_labels=vocabulary.get_vocab_size(namespace="labels")
193
- )
194
- model = SimpleLstmIRMPretrainedModel(
195
- config=config,
196
- )
197
- model.to(device)
198
- model.train()
199
-
200
- # optimizer
201
- logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
202
- param_optimizer = model.parameters()
203
- optimizer = torch.optim.Adam(
204
- param_optimizer,
205
- lr=args.learning_rate,
206
- )
207
- # lr_scheduler = torch.optim.lr_scheduler.StepLR(
208
- # optimizer,
209
- # step_size=2000
210
- # )
211
- lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
212
- optimizer,
213
- milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
214
- )
215
- mse_loss = nn.MSELoss(
216
- reduction="mean",
217
- )
218
-
219
- # training loop
220
- logger.info("training")
221
-
222
- training_loss = 10000000000
223
- evaluation_loss = 10000000000
224
-
225
- model_list = list()
226
- best_idx_epoch = None
227
- best_metric = None
228
- patience_count = 0
229
-
230
- for idx_epoch in range(args.max_epochs):
231
- total_loss = 0.
232
- total_examples = 0.
233
- progress_bar = tqdm(
234
- total=len(train_data_loader),
235
- desc="Training; epoch: {}".format(idx_epoch),
236
- )
237
-
238
- for batch in train_data_loader:
239
- mix_spec, speech_irm, snr_db = batch
240
- mix_spec = mix_spec.to(device)
241
- speech_irm_target = speech_irm.to(device)
242
- snr_db_target = snr_db.to(device)
243
-
244
- speech_irm_prediction = model.forward(mix_spec)
245
- loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
246
-
247
- total_loss += loss.item()
248
- total_examples += mix_spec.size(0)
249
-
250
- optimizer.zero_grad()
251
- loss.backward()
252
- optimizer.step()
253
- lr_scheduler.step()
254
-
255
- training_loss = total_loss / total_examples
256
- training_loss = round(training_loss, 4)
257
-
258
- progress_bar.update(1)
259
- progress_bar.set_postfix({
260
- "training_loss": training_loss,
261
- })
262
-
263
- total_loss = 0.
264
- total_examples = 0.
265
- progress_bar = tqdm(
266
- total=len(valid_data_loader),
267
- desc="Evaluation; epoch: {}".format(idx_epoch),
268
- )
269
- for batch in valid_data_loader:
270
- mix_spec, speech_irm, snr_db = batch
271
- mix_spec = mix_spec.to(device)
272
- speech_irm_target = speech_irm.to(device)
273
- snr_db_target = snr_db.to(device)
274
-
275
- with torch.no_grad():
276
- speech_irm_prediction = model.forward(mix_spec)
277
- loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
278
-
279
- total_loss += loss.item()
280
- total_examples += mix_spec.size(0)
281
-
282
- evaluation_loss = total_loss / total_examples
283
- evaluation_loss = round(evaluation_loss, 4)
284
-
285
- progress_bar.update(1)
286
- progress_bar.set_postfix({
287
- "evaluation_loss": evaluation_loss,
288
- })
289
-
290
- # save path
291
- epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
292
- epoch_dir.mkdir(parents=True, exist_ok=False)
293
-
294
- # save models
295
- model.save_pretrained(epoch_dir.as_posix())
296
-
297
- model_list.append(epoch_dir)
298
- if len(model_list) >= args.num_serialized_models_to_keep:
299
- model_to_delete: Path = model_list.pop(0)
300
- shutil.rmtree(model_to_delete.as_posix())
301
-
302
- # save metric
303
- if best_metric is None:
304
- best_idx_epoch = idx_epoch
305
- best_metric = evaluation_loss
306
- elif evaluation_loss < best_metric:
307
- best_idx_epoch = idx_epoch
308
- best_metric = evaluation_loss
309
- else:
310
- pass
311
-
312
- metrics = {
313
- "idx_epoch": idx_epoch,
314
- "best_idx_epoch": best_idx_epoch,
315
- "training_loss": training_loss,
316
- "evaluation_loss": evaluation_loss,
317
- "learning_rate": optimizer.param_groups[0]["lr"],
318
- }
319
- metrics_filename = epoch_dir / "metrics_epoch.json"
320
- with open(metrics_filename, "w", encoding="utf-8") as f:
321
- json.dump(metrics, f, indent=4, ensure_ascii=False)
322
-
323
- # save best
324
- best_dir = serialization_dir / "best"
325
- if best_idx_epoch == idx_epoch:
326
- if best_dir.exists():
327
- shutil.rmtree(best_dir)
328
- shutil.copytree(epoch_dir, best_dir)
329
-
330
- # early stop
331
- early_stop_flag = False
332
- if best_idx_epoch == idx_epoch:
333
- patience_count = 0
334
- else:
335
- patience_count += 1
336
- if patience_count >= args.patience:
337
- early_stop_flag = True
338
-
339
- # early stop
340
- if early_stop_flag:
341
- break
342
- return
343
-
344
-
345
- if __name__ == '__main__':
346
- main()
toolbox/torchaudio/models/dfnet/configuration_dfnet.py CHANGED
@@ -14,6 +14,8 @@ class DfNetConfig(PretrainedConfig):
  win_type: str = "hann",
 
  spec_bins: int = 256,
+ erb_bins: int = 32,
+ min_freq_bins_for_erb: int = 2,
 
  conv_channels: int = 64,
  conv_kernel_size_input: Tuple[int, int] = (3, 3),
@@ -79,6 +81,8 @@ class DfNetConfig(PretrainedConfig):
 
  # spectrum
  self.spec_bins = spec_bins
+ self.erb_bins = erb_bins
+ self.min_freq_bins_for_erb = min_freq_bins_for_erb
 
  # conv
  self.conv_channels = conv_channels
toolbox/torchaudio/models/dfnet/conv_stft.py DELETED
@@ -1,148 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/conv_stft.py
5
- """
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from scipy.signal import get_window
11
-
12
-
13
- def init_kernels(nfft: int, win_size: int, hop_size: int, win_type: str = None, inverse=False):
14
- if win_type == "None" or win_type is None:
15
- window = np.ones(win_size)
16
- else:
17
- window = get_window(win_type, win_size, fftbins=True)**0.5
18
-
19
- fourier_basis = np.fft.rfft(np.eye(nfft))[:win_size]
20
- real_kernel = np.real(fourier_basis)
21
- image_kernel = np.imag(fourier_basis)
22
- kernel = np.concatenate([real_kernel, image_kernel], 1).T
23
-
24
- if inverse:
25
- kernel = np.linalg.pinv(kernel).T
26
-
27
- kernel = kernel * window
28
- kernel = kernel[:, None, :]
29
- result = (
30
- torch.from_numpy(kernel.astype(np.float32)),
31
- torch.from_numpy(window[None, :, None].astype(np.float32))
32
- )
33
- return result
34
-
35
-
36
- class ConvSTFT(nn.Module):
37
-
38
- def __init__(self,
39
- nfft: int,
40
- win_size: int,
41
- hop_size: int,
42
- win_type: str = "hamming",
43
- power: int = None,
44
- requires_grad: bool = False):
45
- super(ConvSTFT, self).__init__()
46
-
47
- if nfft is None:
48
- self.nfft = int(2**np.ceil(np.log2(win_size)))
49
- else:
50
- self.nfft = nfft
51
-
52
- kernel, _ = init_kernels(self.nfft, win_size, hop_size, win_type)
53
- self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
54
-
55
- self.win_size = win_size
56
- self.hop_size = hop_size
57
-
58
- self.stride = hop_size
59
- self.dim = self.nfft
60
- self.power = power
61
-
62
- def forward(self, inputs: torch.Tensor):
63
- if inputs.dim() == 2:
64
- inputs = torch.unsqueeze(inputs, 1)
65
-
66
- matrix = F.conv1d(inputs, self.weight, stride=self.stride)
67
- dim = self.dim // 2 + 1
68
- real = matrix[:, :dim, :]
69
- imag = matrix[:, dim:, :]
70
- spec = torch.complex(real, imag)
71
-
72
- if self.power is None:
73
- return spec
74
- elif self.power == 1:
75
- mags = torch.sqrt(real**2 + imag**2)
76
- # phase = torch.atan2(imag, real)
77
- return mags
78
- elif self.power == 2:
79
- power = real**2 + imag**2
80
- return power
81
- else:
82
- raise AssertionError
83
-
84
-
85
- class ConviSTFT(nn.Module):
86
-
87
- def __init__(self,
88
- win_size: int,
89
- hop_size: int,
90
- nfft: int = None,
91
- win_type: str = "hamming",
92
- requires_grad: bool = False):
93
- super(ConviSTFT, self).__init__()
94
- if nfft is None:
95
- self.nfft = int(2**np.ceil(np.log2(win_size)))
96
- else:
97
- self.nfft = nfft
98
-
99
- kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
100
- self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
101
-
102
- self.win_size = win_size
103
- self.hop_size = hop_size
104
- self.win_type = win_type
105
-
106
- self.stride = hop_size
107
- self.dim = self.nfft
108
-
109
- self.register_buffer("window", window)
110
- self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
111
-
112
- def forward(self,
113
- inputs: torch.Tensor):
114
- """
115
- :param inputs: torch.Tensor, shape: [b, f, t]
116
- :return:
117
- """
118
- inputs = torch.view_as_real(inputs)
119
- matrix = torch.concat(tensors=[inputs[..., 0], inputs[..., 1]], dim=1)
120
-
121
- waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
122
-
123
- # this is from torch-stft: https://github.com/pseeth/torch-stft
124
- t = self.window.repeat(1, 1, matrix.size(-1))**2
125
- coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
126
- waveform = waveform / (coff + 1e-8)
127
- return waveform
128
-
129
-
130
- def main():
131
- stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
132
- istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)
133
-
134
- mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
135
-
136
- spec = stft.forward(mixture)
137
- # shape: [batch_size, freq_bins, time_steps]
138
- print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
139
-
140
- waveform = istft.forward(spec)
141
- # shape: [batch_size, channels, num_samples]
142
- print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
143
-
144
- return
145
-
146
-
147
- if __name__ == "__main__":
148
- main()
toolbox/torchaudio/models/dfnet/modeling_dfnet.py CHANGED
@@ -12,8 +12,9 @@ import torchaudio
12
 
13
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
14
  from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
15
- from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT
16
  from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
 
17
 
18
 
19
  MODEL_FILE = "model.pt"
@@ -225,7 +226,8 @@ class GroupedLinear(nn.Module):
225
  # The better way, but not supported by torchscript
226
  # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
227
  x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
228
- x = x.flatten(2, 3) # [B, T, H]
 
229
  return x
230
 
231
  def __repr__(self):
@@ -302,7 +304,8 @@ class SqueezedGRU_S(nn.Module):
302
  self.linear_out = nn.Identity()
303
 
304
  def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
305
- x = self.linear_in(inputs)
 
306
 
307
  x, h = self.gru.forward(x, h)
308
 
@@ -327,8 +330,8 @@ class Concat(nn.Module):
327
  class Encoder(nn.Module):
328
  def __init__(self, config: DfNetConfig):
329
  super(Encoder, self).__init__()
330
- self.embedding_input_size = config.conv_channels * config.spec_bins // 4
331
- self.embedding_output_size = config.conv_channels * config.spec_bins // 4
332
  self.embedding_hidden_size = config.embedding_hidden_size
333
 
334
  self.spec_conv0 = CausalConv2d(
@@ -423,49 +426,55 @@ class Encoder(nn.Module):
423
  self.lsnr_offset = config.min_local_snr
424
 
425
  def forward(self,
426
- feat_power: torch.Tensor,
427
  feat_spec: torch.Tensor,
428
  hidden_state: torch.Tensor = None,
429
  ):
430
- # feat_power shape: (batch_size, 1, time_steps, spec_dim)
431
- e0 = self.spec_conv0.forward(feat_power)
432
  e1 = self.spec_conv1.forward(e0)
433
  e2 = self.spec_conv2.forward(e1)
434
  e3 = self.spec_conv3.forward(e2)
435
- # e0 shape: [batch_size, channels, time_steps, spec_dim]
436
- # e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
437
- # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
438
- # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
 
439
 
440
- # feat_spec, shape: (batch_size, 2, time_steps, df_bins)
441
  c0 = self.df_conv0(feat_spec)
442
  c1 = self.df_conv1(c0)
443
- # c0 shape: [batch_size, channels, time_steps, df_bins]
444
- # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
 
445
 
446
  cemb = c1.permute(0, 2, 3, 1)
447
- # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
448
  cemb = cemb.flatten(2)
449
- # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
450
- cemb = self.df_fc_emb(cemb)
451
- # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
 
452
 
453
- # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
454
  emb = e3.permute(0, 2, 3, 1)
455
- # emb shape: [batch_size, time_steps, spec_dim // 4, channels]
456
  emb = emb.flatten(2)
457
- # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
 
458
 
459
  emb = self.combine(emb, cemb)
460
- # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
461
- # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
462
 
463
  emb, h = self.emb_gru.forward(emb, hidden_state)
464
- # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
465
- # h shape: [batch_size, 1, spec_dim]
 
466
 
467
  lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
468
- # lsnr shape: [batch_size, time_steps, 1]
469
 
470
  return e0, e1, e2, e3, emb, c0, lsnr, h
471
 
@@ -477,8 +486,8 @@ class Decoder(nn.Module):
477
  if config.spec_bins % 8 != 0:
478
  raise AssertionError("spec_bins should be divisible by 8")
479
 
480
- self.emb_in_dim = config.conv_channels * config.spec_bins // 4
481
- self.emb_out_dim = config.conv_channels * config.spec_bins // 4
482
  self.emb_hidden_dim = config.decoder_emb_hidden_size
483
 
484
  self.emb_gru = SqueezedGRU_S(
@@ -570,7 +579,7 @@ class Decoder(nn.Module):
570
  b, _, t, f8 = e3.shape
571
 
572
  # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
573
- emb, _ = self.emb_gru(emb)
574
  # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
575
  emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
576
  e3 = self.convt3(self.conv3p(e3) + emb)
@@ -588,7 +597,7 @@ class DfDecoder(nn.Module):
588
  def __init__(self, config: DfNetConfig):
589
  super(DfDecoder, self).__init__()
590
 
591
- self.embedding_input_size = config.conv_channels * config.spec_bins // 4
592
  self.df_decoder_hidden_size = config.df_decoder_hidden_size
593
  self.df_num_layers = config.df_num_layers
594
 
@@ -712,14 +721,14 @@ class Mask(nn.Module):
712
  return mask_pf
713
 
714
  def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
715
- # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
716
 
717
  if not self.training and self.use_post_filter:
718
  mask = self.post_filter(mask)
719
 
720
- # mask shape: [batch_size, 1, time_steps, spec_bins]
721
  mask = mask.unsqueeze(4)
722
- # mask shape: [batch_size, 1, time_steps, spec_bins, 1]
723
  return spec * mask
724
 
725
 
@@ -803,6 +812,13 @@ class DfNet(nn.Module):
803
  self.hop_size = config.hop_size
804
  self.win_type = config.win_type
805
 
 
 
 
 
 
 
 
806
  self.stft = ConvSTFT(
807
  nfft=config.nfft,
808
  win_size=config.win_size,
@@ -867,37 +883,42 @@ class DfNet(nn.Module):
867
  noisy, n_samples = self.signal_prepare(noisy)
868
 
869
  # noisy shape: [b, num_samples_pad]
870
- cmp_spec = self.stft.forward(noisy)
871
- # cmp_spec shape: [b, f, t], torch.complex64
872
- cmp_spec = torch.view_as_real(cmp_spec)
873
- # cmp_spec shape: [b, f, t, 2]
874
- cmp_spec = cmp_spec.permute(0, 3, 1, 2)
875
- # cmp_spec shape: [b, 2, f, t]
876
- cmp_spec = cmp_spec[:, :, :-1, :]
877
- # cmp_spec shape: [b, 2, spec_bins, t]
878
- # n//2+1 -> n//2; 257 -> 256
879
-
880
- spec = torch.unsqueeze(cmp_spec, dim=4)
881
- # spec shape: [b, 2, spec_bins, t, 1]
882
- spec = spec.permute(0, 4, 3, 2, 1)
883
- # spec shape: [b, 1, t, spec_bins, 2]
884
 
885
- feat_power = torch.sum(torch.square(spec), dim=-1)
886
- # feat_power shape: [b, 1, t, spec_bins]
 
 
887
 
888
- feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
889
- # feat_spec shape: [b, 2, t, spec_bins]
890
  feat_spec = feat_spec[..., :self.df_decoder.df_bins]
891
  # feat_spec shape: [b, 2, t, df_bins]
892
 
893
- e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
894
 
895
  mask = self.decoder.forward(emb, e3, e2, e1, e0)
896
- # mask shape: [b, 1, t, spec_bins]
 
 
897
  if torch.any(mask > 1) or torch.any(mask < 0):
898
  raise AssertionError
899
 
900
  spec_m = self.mask.forward(spec, mask)
 
 
 
901
 
902
  # lsnr shape: [b, t, 1]
903
  lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
@@ -907,8 +928,10 @@ class DfNet(nn.Module):
907
  df_coefs = self.df_out_transform(df_coefs)
908
  # df_coefs shape: [b, df_order, t, df_bins, 2]
909
 
910
- spec_e = self.df_op.forward(spec.clone(), df_coefs)
911
- # est_spec shape: [b, 1, t, spec_bins, 2]
 
 
912
 
913
  spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
914
 
@@ -916,14 +939,10 @@ class DfNet(nn.Module):
916
  spec_e = spec_e.permute(0, 2, 1, 3)
917
  # spec_e shape: [b, spec_bins, t, 2]
918
 
919
- mask = torch.squeeze(mask, dim=1)
920
- mask = mask.permute(0, 2, 1)
921
- # mask shape: [b, spec_bins, t]
922
- est_mask = self.mask_transfer(mask)
923
- # est_mask shape: [b, f, t]
924
-
925
  # spec_e shape: [b, spec_bins, t, 2]
926
- est_spec = self.spec_transfer(spec_e)
 
 
927
  # est_spec shape: [b, f, t], torch.complex64
928
 
929
  est_wav = self.istft.forward(est_spec)
@@ -931,33 +950,11 @@ class DfNet(nn.Module):
931
  est_wav = est_wav[:, :n_samples]
932
  # est_wav shape: [b, n_samples]
933
 
934
- return est_spec, est_wav, est_mask, lsnr
 
 
935
 
936
- def spec_transfer(self, spec_e: torch.Tensor) -> torch.Tensor:
937
- # spec_e shape: [b, spec_bins, t, 2]
938
- b, _, t, _ = spec_e.shape
939
- est_spec = torch.complex(
940
- real=torch.concat(tensors=[
941
- spec_e[..., 0],
942
- torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
943
- ], dim=1),
944
- imag=torch.concat(tensors=[
945
- spec_e[..., 1],
946
- torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
947
- ], dim=1),
948
- )
949
- # est_spec shape: [b, f, t]
950
- return est_spec
951
-
952
- def mask_transfer(self, mask: torch.Tensor) -> torch.Tensor:
953
- # mask shape: [b, 256, t]
954
- b, _, t = mask.shape
955
- est_mask = torch.concat(tensors=[
956
- mask,
957
- torch.zeros(size=(b, 1, t), dtype=mask.dtype).to(mask.device)
958
- ], dim=1)
959
- # est_mask shape: [b, 257, t]
960
- return est_mask
961
 
962
  def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
963
  """
 
12
 
13
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
14
  from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
15
+ from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
16
  from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
17
+ from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
18
 
19
 
20
  MODEL_FILE = "model.pt"
 
226
  # The better way, but not supported by torchscript
227
  # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
228
  x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
229
+ x = x.flatten(2, 3)
230
+ # x: [b, t, h]
231
  return x
232
 
233
  def __repr__(self):
 
304
  self.linear_out = nn.Identity()
305
 
306
  def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
307
+ # inputs: shape: [b, t, h]
308
+ x = self.linear_in.forward(inputs)
309
 
310
  x, h = self.gru.forward(x, h)
311
 
 
330
  class Encoder(nn.Module):
331
  def __init__(self, config: DfNetConfig):
332
  super(Encoder, self).__init__()
333
+ self.embedding_input_size = config.conv_channels * config.erb_bins // 4
334
+ self.embedding_output_size = config.conv_channels * config.erb_bins // 4
335
  self.embedding_hidden_size = config.embedding_hidden_size
336
 
337
  self.spec_conv0 = CausalConv2d(
 
426
  self.lsnr_offset = config.min_local_snr
427
 
428
  def forward(self,
429
+ feat_erb: torch.Tensor,
430
  feat_spec: torch.Tensor,
431
  hidden_state: torch.Tensor = None,
432
  ):
433
+ # feat_erb shape: (b, 1, t, erb_bins)
434
+ e0 = self.spec_conv0.forward(feat_erb)
435
  e1 = self.spec_conv1.forward(e0)
436
  e2 = self.spec_conv2.forward(e1)
437
  e3 = self.spec_conv3.forward(e2)
438
+ # e0 shape: [b, c, t, erb_bins]
439
+ # e1 shape: [b, c, t, erb_bins // 2]
440
+ # e2 shape: [b, c, t, erb_bins // 4]
441
+ # e3 shape: [b, c, t, erb_bins // 4]
442
+ # e3 shape: [b, 64, t, 32/4=8]
443
 
444
+ # feat_spec, shape: (b, 2, t, df_bins)
445
  c0 = self.df_conv0(feat_spec)
446
  c1 = self.df_conv1(c0)
447
+ # c0 shape: [b, c, t, df_bins]
448
+ # c1 shape: [b, c, t, df_bins // 2]
449
+ # c1 shape: [b, 64, t, 96/2=48]
450
 
451
  cemb = c1.permute(0, 2, 3, 1)
452
+ # cemb shape: [b, t, df_bins // 2, c]
453
  cemb = cemb.flatten(2)
454
+ # cemb shape: [b, t, df_bins // 2 * c]
455
+ # cemb shape: [b, t, 96/2*64=3072]
456
+ cemb = self.df_fc_emb.forward(cemb)
457
+ # cemb shape: [b, t, erb_bins // 4 * c]
458
+ # cemb shape: [b, t, 32/4*64=512]
459
 
460
+ # e3 shape: [b, c, t, erb_bins // 4]
461
  emb = e3.permute(0, 2, 3, 1)
462
+ # emb shape: [b, t, erb_bins // 4, c]
463
  emb = emb.flatten(2)
464
+ # emb shape: [b, t, erb_bins // 4 * c]
465
+ # emb shape: [b, t, 32/4*64=512]
466
 
467
  emb = self.combine(emb, cemb)
468
+ # if concat; emb shape: [b, t, spec_bins // 4 * c * 2]
469
+ # if add; emb shape: [b, t, spec_bins // 4 * c]
470
 
471
  emb, h = self.emb_gru.forward(emb, hidden_state)
472
+
473
+ # emb shape: [b, t, spec_dim // 4 * c]
474
+ # h shape: [b, 1, spec_dim]
475
 
476
  lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
477
+ # lsnr shape: [b, t, 1]
478
 
479
  return e0, e1, e2, e3, emb, c0, lsnr, h
480
 
 
486
  if config.spec_bins % 8 != 0:
487
  raise AssertionError("spec_bins should be divisible by 8")
488
 
489
+ self.emb_in_dim = config.conv_channels * config.erb_bins // 4
490
+ self.emb_out_dim = config.conv_channels * config.erb_bins // 4
491
  self.emb_hidden_dim = config.decoder_emb_hidden_size
492
 
493
  self.emb_gru = SqueezedGRU_S(
 
579
  b, _, t, f8 = e3.shape
580
 
581
  # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
582
+ emb, _ = self.emb_gru.forward(emb)
583
  # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
584
  emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
585
  e3 = self.convt3(self.conv3p(e3) + emb)
 
597
  def __init__(self, config: DfNetConfig):
598
  super(DfDecoder, self).__init__()
599
 
600
+ self.embedding_input_size = config.conv_channels * config.erb_bins // 4
601
  self.df_decoder_hidden_size = config.df_decoder_hidden_size
602
  self.df_num_layers = config.df_num_layers
603
 
 
721
  return mask_pf
722
 
723
  def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
724
+ # spec shape: [b, 1, t, spec_bins, 2]
725
 
726
  if not self.training and self.use_post_filter:
727
  mask = self.post_filter(mask)
728
 
729
+ # mask shape: [b, 1, t, spec_bins]
730
  mask = mask.unsqueeze(4)
731
+ # mask shape: [b, 1, t, spec_bins, 1]
732
  return spec * mask
733
 
734
 
 
812
  self.hop_size = config.hop_size
813
  self.win_type = config.win_type
814
 
815
+ self.erb_bands = ErbBands(
816
+ sample_rate=config.sample_rate,
817
+ nfft=config.nfft,
818
+ erb_bins=config.erb_bins,
819
+ min_freq_bins_for_erb=config.min_freq_bins_for_erb,
820
+ )
821
+
822
  self.stft = ConvSTFT(
823
  nfft=config.nfft,
824
  win_size=config.win_size,
 
883
  noisy, n_samples = self.signal_prepare(noisy)
884
 
885
  # noisy shape: [b, num_samples_pad]
886
+ spec_cmp = self.stft.forward(noisy)
887
+ # spec_complex shape: [b, f, t], torch.complex64
888
+ spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
889
+ # spec_complex shape: [b, t, f], torch.complex64
890
+ spec_cmp_real = torch.view_as_real(spec_cmp)
891
+ # spec_cmp_real shape: [b, t, f, 2]
892
+ spec_mag = torch.abs(spec_cmp)
893
+ spec_pow = torch.square(spec_mag)
894
+ # shape: [b, t, f]
895
+
896
+ spec = torch.unsqueeze(spec_cmp_real, dim=1)
897
+ # spec shape: [b, 1, t, f, 2]
 
 
898
 
899
+ feat_erb = self.erb_bands.erb_scale(spec_pow, db=True)
900
+ # feat_erb shape: [b, t, erb_bins]
901
+ feat_erb = torch.unsqueeze(feat_erb, dim=1)
902
+ # feat_erb shape: [b, 1, t, erb_bins]
903
 
904
+ feat_spec = spec_cmp_real.permute(0, 3, 1, 2)
905
+ # feat_spec shape: [b, 2, t, f]
906
  feat_spec = feat_spec[..., :self.df_decoder.df_bins]
907
  # feat_spec shape: [b, 2, t, df_bins]
908
 
909
+ e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_erb, feat_spec)
910
 
911
  mask = self.decoder.forward(emb, e3, e2, e1, e0)
912
+ # mask shape: [b, 1, t, erb_bins]
913
+ mask = self.erb_bands.erb_scale_inv(mask)
914
+ # mask shape: [b, 1, t, f]
915
  if torch.any(mask > 1) or torch.any(mask < 0):
916
  raise AssertionError
917
 
918
  spec_m = self.mask.forward(spec, mask)
919
+ # spec_m shape: [b, 1, t, f, 2]
920
+ spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
921
+ # spec_m shape: [b, 1, t, spec_bins, 2]
922
 
923
  # lsnr shape: [b, t, 1]
924
  lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
 
928
  df_coefs = self.df_out_transform(df_coefs)
929
  # df_coefs shape: [b, df_order, t, df_bins, 2]
930
 
931
+ spec_ = spec[:, :, :, :self.config.spec_bins, :]
932
+ # spec shape: [b, 1, t, spec_bins, 2]
933
+ spec_e = self.df_op.forward(spec_, df_coefs)
934
+ # spec_e shape: [b, 1, t, spec_bins, 2]
935
 
936
  spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
937
 
 
939
  spec_e = spec_e.permute(0, 2, 1, 3)
940
  # spec_e shape: [b, spec_bins, t, 2]
941
 
 
 
 
 
 
 
942
  # spec_e shape: [b, spec_bins, t, 2]
943
+ est_spec = torch.complex(real=spec_e[..., 0], imag=spec_e[..., 1])
944
+ # est_spec shape: [b, spec_bins, t], torch.complex64
945
+ est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
946
  # est_spec shape: [b, f, t], torch.complex64
947
 
948
  est_wav = self.istft.forward(est_spec)
 
950
  est_wav = est_wav[:, :n_samples]
951
  # est_wav shape: [b, n_samples]
952
 
953
+ est_mask = torch.squeeze(mask, dim=1)
954
+ est_mask = est_mask.permute(0, 2, 1)
955
+ # est_mask shape: [b, f, t]
956
 
957
+ return est_spec, est_wav, est_mask, lsnr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
958
 
959
  def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
960
  """
toolbox/torchaudio/models/frcrn/conv_stft.py CHANGED
@@ -127,8 +127,8 @@ class ConviSTFT(nn.Module):
 
 
  def main():
- stft = ConvSTFT(win_size=512, hop_size=200, feature_type="complex")
- istft = ConviSTFT(win_size=512, hop_size=200, feature_type="complex")
+ stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex")
+ istft = ConviSTFT(nfft=512, win_size=512, hop_size=200, feature_type="complex")
 
  mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
 
toolbox/torchaudio/models/{simple_lstm_irm → lstm}/__init__.py RENAMED
File without changes
toolbox/torchaudio/models/lstm/configuration_lstm.py ADDED
@@ -0,0 +1,73 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class LstmConfig(PretrainedConfig):
7
+ def __init__(self,
8
+ sample_rate: int = 8000,
9
+ segment_size: int = 32000,
10
+ nfft: int = 512,
11
+ win_size: int = 512,
12
+ hop_size: int = 256,
13
+ win_type: str = "hann",
14
+
15
+ hidden_size: int = 1024,
16
+ num_layers: int = 2,
17
+ dropout: float = 0.2,
18
+
19
+ min_snr_db: float = -10,
20
+ max_snr_db: float = 20,
21
+
22
+ max_epochs: int = 100,
23
+ batch_size: int = 4,
24
+ num_workers: int = 4,
25
+ seed: int = 1234,
26
+
27
+ lr: float = 0.001,
28
+ lr_scheduler: str = "CosineAnnealingLR",
29
+ lr_scheduler_kwargs: dict = None,
30
+
31
+ weight_decay: float = 0.00001,
32
+ clip_grad_norm: float = 10.,
33
+ eval_steps: int = 25000,
34
+
35
+ **kwargs
36
+ ):
37
+ super(LstmConfig, self).__init__(**kwargs)
38
+ self.sample_rate = sample_rate
39
+ self.segment_size = segment_size
40
+ self.nfft = nfft
41
+ self.win_size = win_size
42
+ self.hop_size = hop_size
43
+ self.win_type = win_type
44
+
45
+ self.hidden_size = hidden_size
46
+ self.num_layers = num_layers
47
+ self.dropout = dropout
48
+
49
+ self.min_snr_db = min_snr_db
50
+ self.max_snr_db = max_snr_db
51
+
52
+ self.max_epochs = max_epochs
53
+ self.batch_size = batch_size
54
+ self.num_workers = num_workers
55
+ self.seed = seed
56
+
57
+ self.lr = lr
58
+ self.lr_scheduler = lr_scheduler
59
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
60
+
61
+ self.weight_decay = weight_decay
62
+ self.clip_grad_norm = clip_grad_norm
63
+ self.eval_steps = eval_steps
64
+
65
+
66
+ def main():
67
+ config = LstmConfig()
68
+ config.to_yaml_file("config.yaml")
69
+ return
70
+
71
+
72
+ if __name__ == "__main__":
73
+ main()
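A minimal usage sketch for the new config class, assuming `PretrainedConfig.from_pretrained` resolves a directory containing the YAML written by `to_yaml_file` (as `modeling_lstm.py` below relies on); the round-trip call is left commented because the exact path handling is an assumption:

from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig

config = LstmConfig(nfft=320, win_size=320, hop_size=160, hidden_size=512, num_layers=3)
config.to_yaml_file("config.yaml")            # serialize to YAML, same as main() above
# reloaded = LstmConfig.from_pretrained(".")  # round trip; directory resolution is an assumption
print(config.nfft, config.win_size, config.hop_size)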
toolbox/torchaudio/models/lstm/modeling_lstm.py ADDED
@@ -0,0 +1,260 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py
5
+ """
6
+ import os
7
+ from typing import Optional, Union, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+ import torchaudio
13
+
14
+ from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
15
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
16
+ from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
17
+
18
+
19
+ MODEL_FILE = "model.pt"
20
+
21
+
22
+ class Transpose(nn.Module):
23
+ def __init__(self, dim0: int, dim1: int):
24
+ super(Transpose, self).__init__()
25
+ self.dim0 = dim0
26
+ self.dim1 = dim1
27
+
28
+ def forward(self, inputs: torch.Tensor):
29
+ inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1)
30
+ return inputs
31
+
32
+
33
+ class LstmModel(nn.Module):
34
+ def __init__(self,
35
+ nfft: int = 512,
36
+ win_size: int = 512,
37
+ hop_size: int = 256,
38
+ win_type: str = "hann",
39
+ hidden_size=1024,
40
+ num_layers: int = 2,
41
+ batch_first: bool = True,
42
+ dropout: float = 0.2,
43
+ ):
44
+ super(LstmModel, self).__init__()
45
+ self.nfft = nfft
46
+ self.win_size = win_size
47
+ self.hop_size = hop_size
48
+ self.win_type = win_type
49
+
50
+ self.spec_bins = nfft // 2 + 1
51
+ self.hidden_size = hidden_size
52
+
53
+ self.eps = 1e-8
54
+
55
+ self.stft = ConvSTFT(
56
+ nfft=self.nfft,
57
+ win_size=self.win_size,
58
+ hop_size=self.hop_size,
59
+ win_type=self.win_type,
60
+ power=None,
61
+ requires_grad=False
62
+ )
63
+ self.istft = ConviSTFT(
64
+ nfft=self.nfft,
65
+ win_size=self.win_size,
66
+ hop_size=self.hop_size,
67
+ win_type=self.win_type,
68
+ requires_grad=False
69
+ )
70
+
71
+ self.lstm = nn.LSTM(input_size=self.spec_bins,
72
+ hidden_size=hidden_size,
73
+ num_layers=num_layers,
74
+ batch_first=batch_first,
75
+ dropout=dropout,
76
+ )
77
+ self.linear = nn.Linear(in_features=hidden_size, out_features=self.spec_bins)
78
+ self.activation = nn.Sigmoid()
79
+
80
+ def signal_prepare(self, signal: torch.Tensor) -> Tuple[torch.Tensor, int]:
81
+ if signal.dim() == 2:
82
+ signal = torch.unsqueeze(signal, dim=1)
83
+ _, _, n_samples = signal.shape
84
+ remainder = (n_samples - self.win_size) % self.hop_size
85
+ if remainder > 0:
86
+ n_samples_pad = self.hop_size - remainder
87
+ signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
88
+ return signal, n_samples
89
+
90
+ def forward(self,
91
+ noisy: torch.Tensor,
92
+ h_state: Tuple[torch.Tensor, torch.Tensor] = None,
93
+ ):
94
+ noisy, num_samples = self.signal_prepare(noisy)
95
+ batch_size, _, num_samples_pad = noisy.shape
96
+ # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
97
+
98
+ mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
99
+ # shape: (b, f, t)
100
+ # t = (num_samples - win_size) / hop_size + 1
101
+
102
+ mask, h_state = self.forward_chunk(mag_noisy, h_state)
103
+ # mask shape: (b, f, t)
104
+
105
+ stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
106
+ denoise = self.istft.forward(stft_denoise)
107
+ # denoise shape: [b, 1, num_samples_pad]
108
+
109
+ denoise = denoise[:, :, :num_samples]
110
+ # denoise shape: [b, 1, num_samples]
111
+ return denoise, mask, h_state
112
+
113
+ def mag_pha_stft(self, noisy: torch.Tensor):
114
+ # noisy shape: [b, num_samples]
115
+ stft_noisy = self.stft.forward(noisy)
116
+ # stft_noisy shape: [b, f, t], torch.complex64
117
+
118
+ real = torch.real(stft_noisy)
119
+ imag = torch.imag(stft_noisy)
120
+ mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
121
+ pha_noisy = torch.atan2(imag, real)
122
+ # shape: (b, f, t)
123
+ return mag_noisy, pha_noisy
124
+
125
+ def forward_chunk(self,
126
+ mag_noisy: torch.Tensor,
127
+ h_state: Tuple[torch.Tensor, torch.Tensor] = None,
128
+ ):
129
+ # mag_noisy shape: (b, f, t)
130
+ x = torch.transpose(mag_noisy, dim0=2, dim1=1)
131
+ # x shape: (b, t, f)
132
+ x, h_state = self.lstm.forward(x, hx=h_state)
133
+ x = self.linear.forward(x)
134
+ mask = self.activation(x)
135
+ # mask shape: (b, t, f)
136
+ mask = torch.transpose(mask, dim0=2, dim1=1)
137
+ # mask shape: (b, f, t)
138
+ return mask, h_state
139
+
140
+ def do_mask(self,
141
+ mag_noisy: torch.Tensor,
142
+ pha_noisy: torch.Tensor,
143
+ mask: torch.Tensor,
144
+ ):
145
+ # (b, f, t)
146
+ mag_denoise = mag_noisy * mask
147
+ stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
148
+ return stft_denoise
149
+
150
+
151
+ class LstmPretrainedModel(LstmModel):
152
+ def __init__(self,
153
+ config: LstmConfig,
154
+ ):
155
+ super(LstmPretrainedModel, self).__init__(
156
+ nfft=config.nfft,
157
+ win_size=config.win_size,
158
+ hop_size=config.hop_size,
159
+ win_type=config.win_type,
160
+ hidden_size=config.hidden_size,
161
+ num_layers=config.num_layers,
162
+ dropout=config.dropout,
163
+ )
164
+ self.config = config
165
+
166
+ @classmethod
167
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
168
+ config = LstmConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
169
+
170
+ model = cls(config)
171
+
172
+ if os.path.isdir(pretrained_model_name_or_path):
173
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
174
+ else:
175
+ ckpt_file = pretrained_model_name_or_path
176
+
177
+ with open(ckpt_file, "rb") as f:
178
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
179
+ model.load_state_dict(state_dict, strict=True)
180
+ return model
181
+
182
+ def save_pretrained(self,
183
+ save_directory: Union[str, os.PathLike],
184
+ state_dict: Optional[dict] = None,
185
+ ):
186
+
187
+ model = self
188
+
189
+ if state_dict is None:
190
+ state_dict = model.state_dict()
191
+
192
+ os.makedirs(save_directory, exist_ok=True)
193
+
194
+ # save state dict
195
+ model_file = os.path.join(save_directory, MODEL_FILE)
196
+ torch.save(state_dict, model_file)
197
+
198
+ # save config
199
+ config_file = os.path.join(save_directory, CONFIG_FILE)
200
+ self.config.to_yaml_file(config_file)
201
+ return save_directory
202
+
203
+
204
+ def main():
205
+ config = LstmConfig()
206
+ model = LstmPretrainedModel(config)
207
+ model.eval()
208
+
209
+ noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
210
+ noisy, _ = model.signal_prepare(noisy)
211
+ b, _, num_samples = noisy.shape
212
+ t = (num_samples - config.win_size) / config.hop_size + 1
213
+
214
+ waveform, mask, h_state = model.forward(noisy)
215
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
216
+ print(waveform[:, :, 300: 302])
217
+
218
+ # noisy_pad shape: [b, 1, num_samples_pad]
219
+
220
+ h_state = None
221
+ sub_spec_list = list()
222
+ for i in range(int(t)):
223
+ begin = i * config.hop_size
224
+ end = begin + config.win_size
225
+ sub_noisy = noisy[:, :, begin:end]
226
+ mag_noisy, pha_noisy = model.mag_pha_stft(sub_noisy)
227
+ mask, h_state = model.forward_chunk(mag_noisy, h_state)
228
+ sub_spec = model.do_mask(mag_noisy, pha_noisy, mask)
229
+ sub_spec_list.append(sub_spec)
230
+
231
+ spec = torch.concat(sub_spec_list, dim=2)
232
+
233
+ # 1
234
+ waveform = model.istft.forward(spec)
235
+ waveform = waveform[:, :, :num_samples]
236
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
237
+ print(waveform[:, :, 300: 302])
238
+
239
+ # 2
240
+ waveform_cache = None
241
+ coff_cache = None
242
+ waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
243
+ for i in range(int(t)):
244
+ sub_spec = spec[:, :, i:i+1]
245
+ begin = i * config.hop_size
246
+ end = begin + config.win_size - config.hop_size
247
+ sub_waveform, waveform_cache, coff_cache = model.istft.forward_chunk(sub_spec, waveform_cache, coff_cache)
248
+ # end = begin + config.win_size
249
+ # sub_waveform = model.istft.forward(sub_spec)
250
+
251
+ # (b, 1, win_size)
252
+ waveform[:, :, begin:end] = sub_waveform
253
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
254
+ print(waveform[:, :, 300: 302])
255
+
256
+ return
257
+
258
+
259
+ if __name__ == "__main__":
260
+ main()
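A small standalone check of the masking step in `LstmModel.do_mask`: multiplying the masked magnitude by `exp(1j * phase)` is the same complex spectrum as `torch.polar(magnitude, phase)`. This only illustrates the arithmetic, it is not part of the model:

import torch

mag = torch.rand(2, 257, 199)          # magnitude, shape (b, f, t)
pha = torch.rand(2, 257, 199) * 3.14   # phase in radians
mask = torch.rand(2, 257, 199)         # sigmoid mask in [0, 1]

spec_a = (mag * mask) * torch.exp(1j * pha)   # as written in do_mask
spec_b = torch.polar(mag * mask, pha)         # equivalent formulation
print(torch.allclose(spec_a, spec_b, atol=1e-6))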
toolbox/torchaudio/models/lstm/yaml/config.yaml ADDED
@@ -0,0 +1,32 @@
1
+ model_name: "lstm"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ segment_size: 32000
6
+ nfft: 320
7
+ win_size: 320
8
+ hop_size: 160
9
+ win_type: hann
10
+
11
+ # data
12
+ max_snr_db: 20
13
+ min_snr_db: -10
14
+
15
+ # model
16
+ hidden_size: 512
17
+ num_layers: 3
18
+ dropout: 0.1
19
+
20
+ # train
21
+ max_epochs: 100
22
+ batch_size: 32
23
+ num_workers: 4
24
+ seed: 1234
25
+
26
+ lr: 0.001
27
+ lr_scheduler: CosineAnnealingLR
28
+ lr_scheduler_kwargs: {}
29
+
30
+ weight_decay: 0.00001
31
+ clip_grad_norm: 10.0
32
+ eval_steps: 25000
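A hedged sanity check on the spec settings above (not part of the commit): with `win_size: 320` and `hop_size: 160`, a 4-second segment at 8 kHz (`segment_size: 32000`) gives (32000 - 320) / 160 + 1 = 199 STFT frames, matching the frame-count formula used in `modeling_lstm.py`:

segment_size, win_size, hop_size = 32000, 320, 160
t = (segment_size - win_size) // hop_size + 1
print(t)  # 199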
toolbox/torchaudio/models/simple_lstm_irm/configuration_simple_lstm_irm.py DELETED
@@ -1,38 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
-
5
-
6
- class SimpleLstmIRMConfig(PretrainedConfig):
7
- def __init__(self,
8
- sample_rate: int,
9
- n_fft: int,
10
- win_length: int,
11
- hop_length: int,
12
-
13
- num_bins: int,
14
- hidden_size: int,
15
- num_layers: int,
16
- batch_first: bool,
17
- dropout: float,
18
- lookback: int,
19
- lookahead: int,
20
- **kwargs
21
- ):
22
- super(SimpleLstmIRMConfig, self).__init__(**kwargs)
23
- self.sample_rate = sample_rate
24
- self.n_fft = n_fft
25
- self.win_length = win_length
26
- self.hop_length = hop_length
27
-
28
- self.num_bins = num_bins
29
- self.hidden_size = hidden_size
30
- self.num_layers = num_layers
31
- self.batch_first = batch_first
32
- self.dropout = dropout
33
- self.lookback = lookback
34
- self.lookahead = lookahead
35
-
36
-
37
- if __name__ == "__main__":
38
- pass
toolbox/torchaudio/models/simple_lstm_irm/modeling_simple_lstm_irm.py DELETED
@@ -1,133 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py
5
- """
6
- import os
7
- from typing import Optional, Union
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torchaudio
12
-
13
- from toolbox.torchaudio.models.simple_lstm_irm.configuration_simple_lstm_irm import SimpleLstmIRMConfig
14
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
15
-
16
-
17
- MODEL_FILE = "model.pt"
18
-
19
-
20
- class Transpose(nn.Module):
21
- def __init__(self, dim0: int, dim1: int):
22
- super(Transpose, self).__init__()
23
- self.dim0 = dim0
24
- self.dim1 = dim1
25
-
26
- def forward(self, inputs: torch.Tensor):
27
- inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1)
28
- return inputs
29
-
30
-
31
- class SimpleLstmIRM(nn.Module):
32
- """
33
- Ideal ratio mask estimator:
34
-
35
- """
36
-
37
- def __init__(self, num_bins=257, hidden_size=1024,
38
- num_layers: int = 2,
39
- batch_first: bool = True,
40
- dropout: float = 0.4,
41
- ):
42
- super(SimpleLstmIRM, self).__init__()
43
- self.num_bins = num_bins
44
- self.hidden_size = hidden_size
45
-
46
- self.lstm = nn.LSTM(input_size=num_bins,
47
- hidden_size=hidden_size,
48
- num_layers=num_layers,
49
- batch_first=batch_first,
50
- dropout=dropout,
51
- )
52
- self.linear = nn.Linear(in_features=hidden_size, out_features=num_bins)
53
- self.activation = nn.Sigmoid()
54
-
55
- def forward(self, spec: torch.Tensor):
56
- # spec shape: (batch_size, num_bins, time_steps)
57
- spec = torch.transpose(spec, dim0=2, dim1=1)
58
- # frame_spec shape: (batch_size, time_steps, num_bins)
59
- spec, _ = self.lstm(spec)
60
- spec = self.linear(spec)
61
- mask = self.activation(spec)
62
- return mask
63
-
64
-
65
- class SimpleLstmIRMPretrainedModel(SimpleLstmIRM):
66
- def __init__(self,
67
- config: SimpleLstmIRMConfig,
68
- ):
69
- super(SimpleLstmIRMPretrainedModel, self).__init__(
70
- num_bins=config.num_bins,
71
- hidden_size=config.hidden_size,
72
- )
73
- self.config = config
74
-
75
- @classmethod
76
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
77
- config = SimpleLstmIRMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
78
-
79
- model = cls(config)
80
-
81
- if os.path.isdir(pretrained_model_name_or_path):
82
- ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
83
- else:
84
- ckpt_file = pretrained_model_name_or_path
85
-
86
- with open(ckpt_file, "rb") as f:
87
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
88
- model.load_state_dict(state_dict, strict=True)
89
- return model
90
-
91
- def save_pretrained(self,
92
- save_directory: Union[str, os.PathLike],
93
- state_dict: Optional[dict] = None,
94
- ):
95
-
96
- model = self
97
-
98
- if state_dict is None:
99
- state_dict = model.state_dict()
100
-
101
- os.makedirs(save_directory, exist_ok=True)
102
-
103
- # save state dict
104
- model_file = os.path.join(save_directory, MODEL_FILE)
105
- torch.save(state_dict, model_file)
106
-
107
- # save config
108
- config_file = os.path.join(save_directory, CONFIG_FILE)
109
- self.config.to_yaml_file(config_file)
110
- return save_directory
111
-
112
-
113
- def main():
114
- transformer = torchaudio.transforms.Spectrogram(
115
- n_fft=512,
116
- win_length=200,
117
- hop_length=80,
118
- window_fn=torch.hamming_window,
119
- )
120
-
121
- model = SimpleLstmIRM()
122
-
123
- inputs = torch.randn(size=(1, 1600), dtype=torch.float32)
124
- spec = transformer.forward(inputs)
125
-
126
- output = model.forward(spec)
127
- print(output.shape)
128
- print(output)
129
- return
130
-
131
-
132
- if __name__ == '__main__':
133
- main()
toolbox/torchaudio/models/simple_lstm_irm/yaml/config.yaml DELETED
@@ -1,14 +0,0 @@
1
- model_name: "simple_lstm_irm"
2
-
3
- # spec
4
- sample_rate: 8000
5
- n_fft: 320
6
- win_length: 320
7
- hop_length: 80
8
-
9
- # model
10
- num_bins: 161
11
- hidden_size: 512
12
- num_layers: 3
13
- batch_first: true
14
- dropout: 0.1
toolbox/torchaudio/modules/conv_stft.py CHANGED
@@ -59,11 +59,11 @@ class ConvSTFT(nn.Module):
59
  self.dim = self.nfft
60
  self.power = power
61
 
62
- def forward(self, inputs: torch.Tensor):
63
- if inputs.dim() == 2:
64
- inputs = torch.unsqueeze(inputs, 1)
65
 
66
- matrix = F.conv1d(inputs, self.weight, stride=self.stride)
67
  dim = self.dim // 2 + 1
68
  real = matrix[:, :dim, :]
69
  imag = matrix[:, dim:, :]
@@ -99,6 +99,8 @@ class ConviSTFT(nn.Module):
99
 
100
  kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
101
  self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
 
 
102
 
103
  self.win_size = win_size
104
  self.hop_size = hop_size
@@ -109,41 +111,158 @@ class ConviSTFT(nn.Module):
109
 
110
  self.register_buffer("window", window)
111
  self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
 
 
112
 
113
  def forward(self,
114
- inputs: torch.Tensor):
115
  """
116
- :param inputs: torch.Tensor, shape: [b, f, t]
 
 
 
 
117
  :return:
118
  """
119
- inputs = torch.view_as_real(inputs)
120
- matrix = torch.concat(tensors=[inputs[..., 0], inputs[..., 1]], dim=1)
 
 
121
 
122
  waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
 
123
 
124
  # this is from torch-stft: https://github.com/pseeth/torch-stft
125
  t = self.window.repeat(1, 1, matrix.size(-1))**2
 
126
  coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
 
127
  waveform = waveform / (coff + 1e-8)
 
128
  return waveform
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  def main():
132
- stft = ConvSTFT(nfft=512, win_size=512, hop_size=200, power=None)
133
- istft = ConviSTFT(nfft=512, win_size=512, hop_size=200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
- mixture = torch.rand(size=(1, 8000*40), dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
136
 
137
  spec = stft.forward(mixture)
138
- # shape: [batch_size, freq_bins, time_steps]
 
 
 
139
  print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
140
 
141
  waveform = istft.forward(spec)
142
  # shape: [batch_size, channels, num_samples]
143
  print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  return
146
 
147
 
148
  if __name__ == "__main__":
149
- main()
 
59
  self.dim = self.nfft
60
  self.power = power
61
 
62
+ def forward(self, waveform: torch.Tensor):
63
+ if waveform.dim() == 2:
64
+ waveform = torch.unsqueeze(waveform, 1)
65
 
66
+ matrix = F.conv1d(waveform, self.weight, stride=self.stride)
67
  dim = self.dim // 2 + 1
68
  real = matrix[:, :dim, :]
69
  imag = matrix[:, dim:, :]
 
99
 
100
  kernel, window = init_kernels(self.nfft, win_size, hop_size, win_type, inverse=True)
101
  self.weight = nn.Parameter(kernel, requires_grad=requires_grad)
102
+ # weight shape: [f*2, 1, nfft]
103
+ # f = nfft // 2 + 1
104
 
105
  self.win_size = win_size
106
  self.hop_size = hop_size
 
111
 
112
  self.register_buffer("window", window)
113
  self.register_buffer("enframe", torch.eye(win_size)[:, None, :])
114
+ # window shape: [1, nfft, 1]
115
+ # enframe shape: [nfft, 1, nfft]
116
 
117
  def forward(self,
118
+ spec: torch.Tensor):
119
  """
120
+ self.weight shape: [f*2, 1, win_size]
121
+ self.window shape: [1, win_size, 1]
122
+ self.enframe shape: [win_size, 1, win_size]
123
+
124
+ :param spec: torch.Tensor, complex spectrum of shape [b, f, t], torch.complex64
125
  :return:
126
  """
127
+ spec = torch.view_as_real(spec)
128
+ # spec shape: [b, f, t, 2]
129
+ matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
130
+ # matrix shape: [b, f*2, t]
131
 
132
  waveform = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
133
+ # waveform shape: [b, 1, num_samples]
134
 
135
  # this is from torch-stft: https://github.com/pseeth/torch-stft
136
  t = self.window.repeat(1, 1, matrix.size(-1))**2
137
+ # t shape: [1, win_size, t]
138
  coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
139
+ # coff shape: [1, 1, num_samples]
140
  waveform = waveform / (coff + 1e-8)
141
+ # waveform = waveform / coff
142
  return waveform
143
 
144
+ def forward_chunk(self,
145
+ spec: torch.Tensor,
146
+ waveform_cache: torch.Tensor = None,
147
+ coff_cache: torch.Tensor = None,
148
+ ):
149
+ """
150
+ :param spec: shape: [b, f, t]
151
+ :param waveform_cache: shape: [b, 1, win_size - hop_size]
152
+ :param coff_cache: shape: [b, 1, win_size - hop_size]
153
+ :return:
154
+ """
155
+ spec = torch.view_as_real(spec)
156
+ matrix = torch.concat(tensors=[spec[..., 0], spec[..., 1]], dim=1)
157
+
158
+ waveform_current = F.conv_transpose1d(matrix, self.weight, stride=self.stride)
159
+
160
+ t = self.window.repeat(1, 1, matrix.size(-1))**2
161
+ coff_current = F.conv_transpose1d(t, self.enframe, stride=self.stride)
162
+
163
+ overlap_size = self.win_size - self.hop_size
164
+
165
+ if waveform_cache is not None:
166
+ waveform_overlap = waveform_current[:, :, :overlap_size] + waveform_cache
167
+ waveform_non_overlap = waveform_current[:, :, overlap_size:-self.hop_size]
168
+ waveform_output = torch.cat(tensors=[waveform_overlap, waveform_non_overlap], dim=-1)
169
+ new_waveform_cache = waveform_current[:, :, -self.hop_size:]
170
+ else:
171
+ waveform_output = waveform_current[:, :, :-self.hop_size]
172
+ new_waveform_cache = waveform_current[:, :, -self.hop_size:]
173
+
174
+ if coff_cache is not None:
175
+ coff_overlap = coff_current[:, :, :overlap_size] + coff_cache
176
+ coff_non_overlap = coff_current[:, :, overlap_size:-self.hop_size]
177
+ coff_output = torch.cat(tensors=[coff_overlap, coff_non_overlap], dim=-1)
178
+ new_coff_cache = coff_current[:, :, -self.hop_size:]
179
+ else:
180
+ coff_output = coff_current[:, :, :-self.hop_size]
181
+ new_coff_cache = coff_current[:, :, -self.hop_size:]
182
+
183
+ waveform_output = waveform_output / (coff_output + 1e-8)
184
+ return waveform_output, new_waveform_cache, new_coff_cache
185
+
186
 
187
  def main():
188
+ nfft = 512
189
+ win_size = 512
190
+ hop_size = 256
191
+
192
+ stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
193
+ istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)
194
+
195
+ mixture = torch.rand(size=(1, 16000), dtype=torch.float32)
196
+ b, num_samples = mixture.shape
197
+ t = (num_samples - win_size) / hop_size + 1
198
+
199
+ spec = stft.forward(mixture)
200
+ b, f, t = spec.shape
201
+
202
+ # If `spec` is produced by the STFT above, the two waveform reconstruction methods below give identical results; otherwise the reconstructed waveforms will differ.
203
+ # spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
204
+ print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
205
+
206
+ waveform = istft.forward(spec)
207
+ # shape: [batch_size, channels, num_samples]
208
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
209
+ print(waveform[:, :, 300: 302])
210
+
211
+ waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
212
+ for i in range(int(t)):
213
+ begin = i * hop_size
214
+ end = begin + win_size
215
+ sub_spec = spec[:, :, i:i+1]
216
+ sub_waveform = istft.forward(sub_spec)
217
+ # (b, 1, win_size)
218
+ waveform[:, :, begin:end] = sub_waveform
219
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
220
+ print(waveform[:, :, 300: 302])
221
+
222
+ return
223
 
224
+
225
+ def main2():
226
+ nfft = 512
227
+ win_size = 512
228
+ hop_size = 256
229
+
230
+ stft = ConvSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size, power=None)
231
+ istft = ConviSTFT(nfft=nfft, win_size=win_size, hop_size=hop_size)
232
+
233
+ mixture = torch.rand(size=(1, 16128), dtype=torch.float32)
234
+ b, num_samples = mixture.shape
235
 
236
  spec = stft.forward(mixture)
237
+ b, f, t = spec.shape
238
+
239
+ # If `spec` is produced by the STFT above, the two waveform reconstruction methods below give identical results; otherwise the reconstructed waveforms will differ.
240
+ spec = spec + 0.01 * torch.randn(size=(1, nfft//2+1, t), dtype=torch.float32)
241
  print(f"spec.shape: {spec.shape}, spec.dtype: {spec.dtype}")
242
 
243
  waveform = istft.forward(spec)
244
  # shape: [batch_size, channels, num_samples]
245
  print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
246
+ print(waveform[:, :, 300: 302])
247
+
248
+ waveform_cache = None
249
+ coff_cache = None
250
+ waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
251
+ for i in range(int(t)):
252
+ sub_spec = spec[:, :, i:i+1]
253
+ begin = i * hop_size
254
+
255
+ end = begin + win_size - hop_size
256
+ sub_waveform, waveform_cache, coff_cache = istft.forward_chunk(sub_spec, waveform_cache, coff_cache)
257
+ # end = begin + win_size
258
+ # sub_waveform = istft.forward(sub_spec)
259
+
260
+ waveform[:, :, begin:end] = sub_waveform
261
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
262
+ print(waveform[:, :, 300: 302])
263
 
264
  return
265
 
266
 
267
  if __name__ == "__main__":
268
+ main2()
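A standalone sketch of why `ConviSTFT` divides by `coff`: the overlap-added sum of squared analysis windows is not constant for a Hann window at 50% overlap, so the reconstruction has to be normalized frame by frame. This only illustrates the normalization term, not the exact kernels above:

import torch

win_size, hop_size, num_frames = 512, 256, 8
win_sq = torch.hann_window(win_size, periodic=True) ** 2

acc = torch.zeros(win_size + (num_frames - 1) * hop_size)
for i in range(num_frames):
    acc[i * hop_size: i * hop_size + win_size] += win_sq

# away from the edges the sum oscillates instead of being flat,
# which is what the division by (coff + 1e-8) compensates for
print(acc[win_size: -win_size].min().item(), acc[win_size: -win_size].max().item())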
toolbox/torchaudio/modules/freq_bands/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/modules/freq_bands/erb_bands.py ADDED
@@ -0,0 +1,173 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class ErbBandsNumpy(object):
11
+
12
+ @staticmethod
13
+ def freq2erb(freq_hz: float) -> float:
14
+ """
15
+ https://www.cnblogs.com/LXP-Never/p/16011229.html
16
+ 1 / (24.7 * 9.265) = 0.00436976
17
+ """
18
+ return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1)
19
+
20
+ @staticmethod
21
+ def erb2freq(n_erb: float) -> float:
22
+ return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1)
23
+
24
+ @classmethod
25
+ def get_erb_widths(cls, sample_rate: int, nfft: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
26
+ """
27
+ https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs
28
+ :param sample_rate:
29
+ :param nfft:
30
+ :param erb_bins: number of ERB (Equivalent Rectangular Bandwidth) bands.
31
+ :param min_freq_bins_for_erb: Minimum number of frequency bands per erb band
32
+ :return:
33
+ """
34
+ nyq_freq = sample_rate / 2.
35
+ freq_width: float = sample_rate / nfft
36
+
37
+ min_erb: float = cls.freq2erb(0.)
38
+ max_erb: float = cls.freq2erb(nyq_freq)
39
+
40
+ erb = [0] * erb_bins
41
+ step = (max_erb - min_erb) / erb_bins
42
+
43
+ prev_freq_bin = 0
44
+ freq_over = 0
45
+ for i in range(1, erb_bins + 1):
46
+ f = cls.erb2freq(min_erb + i * step)
47
+ freq_bin = int(round(f / freq_width))
48
+ freq_bins = freq_bin - prev_freq_bin - freq_over
49
+
50
+ if freq_bins < min_freq_bins_for_erb:
51
+ freq_over = min_freq_bins_for_erb - freq_bins
52
+ freq_bins = min_freq_bins_for_erb
53
+ else:
54
+ freq_over = 0
55
+ erb[i - 1] = freq_bins
56
+ prev_freq_bin = freq_bin
57
+
58
+ erb[erb_bins - 1] += 1
59
+ too_large = sum(erb) - (nfft / 2 + 1)
60
+ if too_large > 0:
61
+ erb[erb_bins - 1] -= too_large
62
+ return np.array(erb, dtype=np.uint64)
63
+
64
+ @staticmethod
65
+ def get_erb_filter_bank(erb_widths: np.ndarray,
66
+ normalized: bool = True,
67
+ inverse: bool = False,
68
+ ):
69
+ num_freq_bins = int(np.sum(erb_widths))
70
+ num_erb_bins = len(erb_widths)
71
+
72
+ fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins))
73
+
74
+ points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1]
75
+ for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())):
76
+ fb[b: b + w, i] = 1
77
+
78
+ if inverse:
79
+ fb = fb.T
80
+ if not normalized:
81
+ fb /= np.sum(fb, axis=1, keepdims=True)
82
+ else:
83
+ if normalized:
84
+ fb /= np.sum(fb, axis=0)
85
+ return fb
86
+
87
+ @staticmethod
88
+ def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
89
+ """
90
+ ERB filterbank and transform to decibel scale.
91
+
92
+ :param spec: Spectrum of shape [B, C, T, F].
93
+ :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths,
94
+ where B are the number of ERB bins.
95
+ :param db: Whether to transform the output into decibel scale. Defaults to `True`.
96
+ :return:
97
+ """
98
+ # complex spec to power spec. (real * real + image * image)
99
+ spec_ = np.abs(spec) ** 2
100
+
101
+ # spec to erb feature.
102
+ erb_feat = np.matmul(spec_, erb_fb)
103
+
104
+ if db:
105
+ erb_feat = 10 * np.log10(erb_feat + 1e-10)
106
+
107
+ erb_feat = np.array(erb_feat, dtype=np.float32)
108
+ return erb_feat
109
+
110
+
111
+ class ErbBands(nn.Module):
112
+ def __init__(self,
113
+ sample_rate: int = 8000,
114
+ nfft: int = 512,
115
+ erb_bins: int = 32,
116
+ min_freq_bins_for_erb: int = 2,
117
+ ):
118
+ super().__init__()
119
+ self.sample_rate = sample_rate
120
+ self.nfft = nfft
121
+ self.erb_bins = erb_bins
122
+ self.min_freq_bins_for_erb = min_freq_bins_for_erb
123
+
124
+ erb_fb, erb_fb_inv = self.init_erb_fb()
125
+ self.erb_fb = torch.tensor(erb_fb, dtype=torch.float32, requires_grad=False)
126
+ self.erb_fb_inv = torch.tensor(erb_fb_inv, dtype=torch.float32, requires_grad=False)
127
+
128
+ def init_erb_fb(self):
129
+ erb_widths = ErbBandsNumpy.get_erb_widths(
130
+ sample_rate=self.sample_rate,
131
+ nfft=self.nfft,
132
+ erb_bins=self.erb_bins,
133
+ min_freq_bins_for_erb=self.min_freq_bins_for_erb,
134
+ )
135
+ erb_fb = ErbBandsNumpy.get_erb_filter_bank(
136
+ erb_widths=erb_widths,
137
+ normalized=True,
138
+ inverse=False,
139
+ )
140
+ erb_fb_inv = ErbBandsNumpy.get_erb_filter_bank(
141
+ erb_widths=erb_widths,
142
+ normalized=True,
143
+ inverse=True,
144
+ )
145
+ return erb_fb, erb_fb_inv
146
+
147
+ def erb_scale(self, spec: torch.Tensor, db: bool = True):
148
+ spec_erb = torch.matmul(spec, self.erb_fb)
149
+ if db:
150
+ spec_erb = 10 * torch.log10(spec_erb + 1e-10)
151
+ return spec_erb
152
+
153
+ def erb_scale_inv(self, spec_erb: torch.Tensor):
154
+ spec = torch.matmul(spec_erb, self.erb_fb_inv)
155
+ return spec
156
+
157
+
158
+ def main():
159
+
160
+ erb_bands = ErbBands()
161
+
162
+ spec = torch.rand(size=(2, 199, 257), dtype=torch.float32)  # non-negative, since erb_scale applies log10 when db=True
163
+ spec_erb = erb_bands.erb_scale(spec)
164
+ print(spec_erb.shape)
165
+
166
+ spec = erb_bands.erb_scale_inv(spec_erb)
167
+ print(spec.shape)
168
+
169
+ return
170
+
171
+
172
+ if __name__ == "__main__":
173
+ main()
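A quick usage sketch for the ERB scale helpers above: `erb2freq` is the exact inverse of `freq2erb`, so a frequency survives the round trip:

from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBandsNumpy

f_hz = 1000.0
n_erb = ErbBandsNumpy.freq2erb(f_hz)
f_back = ErbBandsNumpy.erb2freq(n_erb)
print(round(n_erb, 2), round(f_back, 2))  # ~15.57 ERB; f_back recovers 1000.0 Hz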