#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List

import torch
import torch.nn as nn
from torch.nn import functional as F


class CIRMLoss(nn.Module):
    """
    MSE loss between a predicted complex ratio mask and the ground-truth mask
    derived from the STFTs of the clean and noisy waveforms.

    The ground-truth mask is the complex quotient S / Y computed as
    S * conj(Y) / (|Y|^2 + eps), where S is the clean STFT and Y the noisy STFT.

    :param n_fft: FFT size passed to torch.stft.
    :param win_size: window length; a Hann window of this size is used.
    :param hop_size: hop length between frames.
    :param center: forwarded to torch.stft (reflect padding is used).
    :param eps: added to the noisy power to avoid division by zero.
    :param reduction: "mean" or "sum"; applied to both MSE terms.
    :raises AssertionError: if reduction is neither "sum" nor "mean".
    """

    def __init__(self,
                 n_fft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 center: bool = True,
                 eps: float = 1e-8,
                 reduction: str = "mean",
                 ):
        super(CIRMLoss, self).__init__()
        if reduction not in ("sum", "mean"):
            raise AssertionError("param reduction must be sum or mean.")

        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size
        self.center = center
        self.eps = eps
        self.reduction = reduction

        # Stored as a frozen parameter so it follows the module across devices.
        self.window = nn.Parameter(torch.hann_window(win_size), requires_grad=False)

    def _stft(self, signal: torch.Tensor) -> torch.Tensor:
        """Return the complex STFT of signal using the module's settings."""
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            win_length=self.win_size,
            hop_length=self.hop_size,
            window=self.window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            return_complex=True,
        )

    @staticmethod
    def _limit_mask(mask: torch.Tensor) -> torch.Tensor:
        """Replace values above 2 with 1 and values below -2 with -1."""
        # NOTE: outliers are mapped to +/-1, not clamped to +/-2; this keeps
        # the behavior of the original in-place assignments unchanged.
        mask = mask.masked_fill(mask > 2, 1.0)
        mask = mask.masked_fill(mask < -2, -1.0)
        return mask

    def forward(self,
                clean: torch.Tensor,
                noisy: torch.Tensor,
                mask_real: torch.Tensor,
                mask_imag: torch.Tensor) -> torch.Tensor:
        """
        :param clean: waveform
        :param noisy: waveform, same shape as clean
        :param mask_real: predicted real part of the mask, shape: [b, f, t]
        :param mask_imag: predicted imaginary part of the mask, shape: [b, f, t]
        :return: sum of the real-part and imaginary-part MSE losses,
                 each reduced according to self.reduction.
        :raises AssertionError: if clean and noisy differ in shape.
        """
        if noisy.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")

        # clean_stft, noisy_stft shape: [b, f, t]
        clean_stft = self._stft(clean)
        noisy_stft = self._stft(noisy)

        sr = torch.real(clean_stft)
        si = torch.imag(clean_stft)
        yr = torch.real(noisy_stft)
        yi = torch.imag(noisy_stft)
        denominator = yr ** 2 + yi ** 2 + self.eps

        # S / Y = S * conj(Y) / |Y|^2
        #   real: (Sr * Yr + Si * Yi) / (Y_pow + eps)
        #   imag: (Si * Yr - Sr * Yi) / (Y_pow + eps)
        gth_mask_real = self._limit_mask((sr * yr + si * yi) / denominator)
        gth_mask_imag = self._limit_mask((si * yr - sr * yi) / denominator)

        real_loss = F.mse_loss(gth_mask_real, mask_real, reduction=self.reduction)
        imag_loss = F.mse_loss(gth_mask_imag, mask_imag, reduction=self.reduction)
        return real_loss + imag_loss


def main():
    """Smoke test: evaluate the loss on random waveforms and random masks."""
    batch_size = 2
    signal_length = 16000

    clean_signal = torch.randn(batch_size, signal_length)
    noisy_signal = torch.randn(batch_size, signal_length)

    loss_fn = CIRMLoss()

    # One-sided STFT with center=True: n_fft // 2 + 1 bins, length // hop + 1 frames.
    n_freq = loss_fn.n_fft // 2 + 1
    n_frames = signal_length // loss_fn.hop_size + 1
    mask_real = torch.randn(batch_size, n_freq, n_frames)
    mask_imag = torch.randn(batch_size, n_freq, n_frames)

    loss = loss_fn(clean_signal, noisy_signal, mask_real, mask_imag)
    print(f"loss: {loss.item()}")


if __name__ == "__main__":
    main()