HoneyTian commited on
Commit
1b032b9
·
1 Parent(s): d791cee
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -195,7 +195,7 @@ def main():
195
  step_idx = int(step_idx)
196
  if step_idx > last_step_idx:
197
  last_step_idx = step_idx
198
- last_epoch = 2
199
 
200
  if last_step_idx != -1:
201
  logger.info(f"resume from steps-{last_step_idx}.")
@@ -299,7 +299,8 @@ def main():
299
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
300
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
301
  # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
302
- loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
 
303
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
304
  logger.info(f"find nan or inf in loss.")
305
  continue
 
195
  step_idx = int(step_idx)
196
  if step_idx > last_step_idx:
197
  last_step_idx = step_idx
198
+ last_epoch = 1
199
 
200
  if last_step_idx != -1:
201
  logger.info(f"resume from steps-{last_step_idx}.")
 
299
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
300
  # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
301
  # loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
302
+ # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
303
+ loss = 0.2 * ae_loss + 0.2 * neg_si_snr_loss + 1.0 * mr_stft_loss + 0.3 * neg_stoi_loss + 0.5 * pesq_loss
304
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
305
  logger.info(f"find nan or inf in loss.")
306
  continue
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py CHANGED
@@ -68,6 +68,7 @@ class DenoiseJsonlDataset(IterableDataset):
68
  yield self.convert_sample(sample)
69
 
70
  def iterable_source(self):
 
71
  with open(self.jsonl_file, "r", encoding="utf-8") as f:
72
  for row in f:
73
  row = json.loads(row)
@@ -99,7 +100,11 @@ class DenoiseJsonlDataset(IterableDataset):
99
 
100
  "snr_db": snr_db,
101
  }
 
 
 
102
  yield sample
 
103
 
104
  def convert_sample(self, sample: dict):
105
  noise_filename = sample["noise_filename"]
 
68
  yield self.convert_sample(sample)
69
 
70
  def iterable_source(self):
71
+ last_sample = None
72
  with open(self.jsonl_file, "r", encoding="utf-8") as f:
73
  for row in f:
74
  row = json.loads(row)
 
100
 
101
  "snr_db": snr_db,
102
  }
103
+ if last_sample is None:
104
+ last_sample = sample
105
+ continue
106
  yield sample
107
+ yield last_sample
108
 
109
  def convert_sample(self, sample: dict):
110
  noise_filename = sample["noise_filename"]
toolbox/torchaudio/models/frcrn/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/frcrn/modeling_frcrn.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://arxiv.org/abs/2206.07293
5
+ """
6
+ from modelscope.models.audio.ans.frcrn import FRCRN
7
+
8
+
9
+ if __name__ == "__main__":
10
+ pass