#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://zhuanlan.zhihu.com/p/627039860
"""
import torch
import torch.nn as nn
from torch_stoi import NegSTOILoss as TorchNegSTOILoss
from torch_pesq import PesqLoss as TorchPesqLoss


# Batch reductions accepted by the loss wrappers in this module.
_VALID_REDUCTIONS = ("sum", "mean")


def _check_reduction(reduction: str) -> None:
    """Validate a reduction name at construction time.

    :param reduction: requested reduction; must be "sum" or "mean".
    :raises AssertionError: if ``reduction`` is not supported (AssertionError is
        kept for compatibility with callers that already catch it).
    """
    if reduction not in _VALID_REDUCTIONS:
        raise AssertionError("param reduction must be sum or mean.")


def _reduce(batch_loss: torch.Tensor, reduction: str) -> torch.Tensor:
    """Collapse per-sample losses into a single tensor.

    :param batch_loss: per-sample losses produced by a wrapped loss module.
    :param reduction: "mean" or "sum".
    :return: ``torch.mean(batch_loss)`` or ``torch.sum(batch_loss)``.
    :raises AssertionError: if ``reduction`` is not supported.
    """
    if reduction == "mean":
        return torch.mean(batch_loss)
    if reduction == "sum":
        return torch.sum(batch_loss)
    raise AssertionError(f"unsupported reduction: {reduction!r}")


class PMSQELoss(object):
    """
    Placeholder for a PMSQE loss; no methods are implemented here yet.

    A Deep Learning Loss Function based on the Perceptual Evaluation of the Speech Quality
    https://sigmat.ugr.es/PMSQE/

    On Loss Functions for Supervised Monaural Time-Domain Speech Enhancement
    https://arxiv.org/abs/1909.01019

    https://github.com/asteroid-team/asteroid/blob/master/asteroid/losses/pmsqe.py
    """


class NegSTOILoss(nn.Module):
    """
    Negative STOI loss: wraps ``torch_stoi.NegSTOILoss`` and reduces over the batch.

    STOI (Short-Time Objective Intelligibility) predicts speech intelligibility by
    correlating short-time time-frequency envelopes of the processed and clean
    signals. Scores range from 0 to 1; a higher score means more intelligible.
    It is well suited to evaluating intelligibility improvements in noisy conditions.

    https://github.com/mpariente/pytorch_stoi
    https://github.com/mpariente/pystoi
    https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/loss/stoi_loss.py
    """

    def __init__(self,
                 sample_rate: int,
                 reduction: str = "mean",
                 ):
        """
        :param sample_rate: sample rate of the waveforms, forwarded to torch_stoi.
        :param reduction: "mean" or "sum" over the per-sample losses.
        :raises AssertionError: if ``reduction`` is not "mean" or "sum".
        """
        super().__init__()
        _check_reduction(reduction)
        self.reduction = reduction
        self.loss_fn = TorchNegSTOILoss(sample_rate=sample_rate)

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        :param denoise: enhanced (estimated) waveforms.
        :param clean: clean reference waveforms.
        :return: reduced negative-STOI loss.
        """
        # Argument order is (estimate, target). Calling the module rather than
        # .forward() keeps nn.Module hooks working.
        batch_loss = self.loss_fn(denoise, clean)
        return _reduce(batch_loss, self.reduction)


class PesqLoss(nn.Module):
    """
    PESQ-based loss: wraps ``torch_pesq.PesqLoss`` and reduces over the batch.
    """

    def __init__(self,
                 factor: float,
                 sample_rate: int = 48000,
                 nbarks: int = 49,
                 win_length: int = 512,
                 n_fft: int = 512,
                 hop_length: int = 256,
                 reduction: str = "mean",
                 ):
        """
        :param factor: scaling factor forwarded to torch_pesq.PesqLoss.
        :param sample_rate: sample rate of the waveforms.
        :param nbarks: number of Bark bands, forwarded to torch_pesq.
        :param win_length: STFT window length, forwarded to torch_pesq.
        :param n_fft: FFT size, forwarded to torch_pesq.
        :param hop_length: STFT hop length, forwarded to torch_pesq.
        :param reduction: "mean" or "sum" over the per-sample losses.
        :raises AssertionError: if ``reduction`` is not "mean" or "sum"
            (validated up front, consistent with NegSTOILoss).
        """
        super().__init__()
        _check_reduction(reduction)
        self.factor = factor
        self.sample_rate = sample_rate
        self.nbarks = nbarks
        self.win_length = win_length
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.reduction = reduction

        self.loss_fn = TorchPesqLoss(
            factor=factor,
            sample_rate=sample_rate,
            nbarks=nbarks,
            win_length=win_length,
            n_fft=n_fft,
            hop_length=hop_length,
        )

    def forward(self, denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
        """
        :param denoise: enhanced (degraded) waveforms.
        :param clean: clean reference waveforms.
        :return: reduced PESQ loss.
        """
        # Note the reversed order compared with NegSTOILoss: reference (clean)
        # first, then the degraded/enhanced signal.
        batch_loss = self.loss_fn(clean, denoise)
        # NOTE: non-finite per-sample losses (NaN/Inf) are not filtered out, so a
        # single bad sample propagates into the reduced loss.
        return _reduce(batch_loss, self.reduction)


def main():
    """Smoke test: run NegSTOILoss on two one-second clips of random noise at 16 kHz."""
    sample_rate = 16000
    loss_func = NegSTOILoss(
        sample_rate=sample_rate,
        reduction="mean",
    )

    denoise = torch.randn(2, sample_rate)
    clean = torch.randn(2, sample_rate)

    loss_batch = loss_func(denoise, clean)
    print(loss_batch)
    return


if __name__ == "__main__":
    main()