Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -195,7 +195,7 @@ def main():
|
|
195 |
step_idx = int(step_idx)
|
196 |
if step_idx > last_step_idx:
|
197 |
last_step_idx = step_idx
|
198 |
-
last_epoch =
|
199 |
|
200 |
if last_step_idx != -1:
|
201 |
logger.info(f"resume from steps-{last_step_idx}.")
|
@@ -299,7 +299,8 @@ def main():
|
|
299 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
|
300 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
301 |
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
302 |
-
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
|
|
303 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
304 |
logger.info(f"find nan or inf in loss.")
|
305 |
continue
|
|
|
195 |
step_idx = int(step_idx)
|
196 |
if step_idx > last_step_idx:
|
197 |
last_step_idx = step_idx
|
198 |
+
last_epoch = 1
|
199 |
|
200 |
if last_step_idx != -1:
|
201 |
logger.info(f"resume from steps-{last_step_idx}.")
|
|
|
299 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
|
300 |
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
|
301 |
# loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
|
302 |
+
# loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
303 |
+
loss = 0.2 * ae_loss + 0.2 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.3 * neg_stoi_loss + 0.5 * pesq_loss
|
304 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
305 |
logger.info(f"find nan or inf in loss.")
|
306 |
continue
|
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py
CHANGED
@@ -68,6 +68,7 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
68 |
yield self.convert_sample(sample)
|
69 |
|
70 |
def iterable_source(self):
|
|
|
71 |
with open(self.jsonl_file, "r", encoding="utf-8") as f:
|
72 |
for row in f:
|
73 |
row = json.loads(row)
|
@@ -99,7 +100,11 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
99 |
|
100 |
"snr_db": snr_db,
|
101 |
}
|
|
|
|
|
|
|
102 |
yield sample
|
|
|
103 |
|
104 |
def convert_sample(self, sample: dict):
|
105 |
noise_filename = sample["noise_filename"]
|
|
|
68 |
yield self.convert_sample(sample)
|
69 |
|
70 |
def iterable_source(self):
|
71 |
+
last_sample = None
|
72 |
with open(self.jsonl_file, "r", encoding="utf-8") as f:
|
73 |
for row in f:
|
74 |
row = json.loads(row)
|
|
|
100 |
|
101 |
"snr_db": snr_db,
|
102 |
}
|
103 |
+
if last_sample is None:
|
104 |
+
last_sample = sample
|
105 |
+
continue
|
106 |
yield sample
|
107 |
+
yield last_sample
|
108 |
|
109 |
def convert_sample(self, sample: dict):
|
110 |
noise_filename = sample["noise_filename"]
|
toolbox/torchaudio/models/frcrn/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
pass
|
toolbox/torchaudio/models/frcrn/modeling_frcrn.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
https://arxiv.org/abs/2206.07293
|
5 |
+
"""
|
6 |
+
from modelscope.models.audio.ans.frcrn import FRCRN
|
7 |
+
|
8 |
+
|
9 |
+
if __name__ == "__main__":
|
10 |
+
pass
|