File size: 5,296 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from pytorch_wpe import wpe_one_iteration
import torch
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet2.enh.layers.mask_estimator import MaskEstimator
class DNN_WPE(torch.nn.Module):
    """WPE dereverberation whose power estimate can be weighted by DNN masks.

    Each iteration estimates a per-(batch, frequency) power from the current
    enhanced signal and runs one WPE iteration (``wpe_one_iteration``) on the
    observed signal with it. When ``use_dnn_mask`` is True, the power of the
    first iteration is multiplied by masks from a ``MaskEstimator`` and one
    WPE branch is run per mask.

    Args:
        wtype, widim, wlayers, wunits, wprojs, dropout_rate: Forwarded
            positionally to ``MaskEstimator`` (presumably network type, input
            dim, layers, units, projection dim and dropout -- see its
            signature). Only used when ``use_dnn_mask`` is True.
        taps: Forwarded to ``wpe_one_iteration`` as ``taps``.
        delay: Forwarded to ``wpe_one_iteration`` as ``delay``.
        use_dnn_mask: Whether to weight the first-iteration power with masks.
        nmask: Number of masks / WPE branches when ``use_dnn_mask`` is True;
            forced to 1 otherwise.
        nonlinear: Forwarded to ``MaskEstimator`` as ``nonlinear``.
        iterations: Number of WPE iterations.
        normalization: If True, each mask is divided by its sum along time.
        eps: Lower bound applied to the channel-averaged power.
        diagonal_loading: Stored only; not read by this class.
        diag_eps: Stored only; not read by this class.
        mask_flooring: If True, masks are clamped from below by
            ``flooring_thres``.
        flooring_thres: Floor value used when ``mask_flooring`` is True.
        use_torch_solver: Stored only; not read by this class.
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        nmask: int = 1,
        nonlinear: str = "sigmoid",
        iterations: int = 1,
        normalization: bool = False,
        eps: float = 1e-6,
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres: float = 1e-6,
        use_torch_solver: bool = True,
    ):
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay
        self.eps = eps
        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask
        # Always forwarded to wpe_one_iteration as ``inverse_power``.
        self.inverse_power = True
        # NOTE(review): diagonal_loading, diag_eps and use_torch_solver are
        # only stored here; nothing in this class reads them.
        self.diagonal_loading = diagonal_loading
        self.diag_eps = diag_eps
        self.mask_flooring = mask_flooring
        self.flooring_thres = flooring_thres
        self.use_torch_solver = use_torch_solver

        if self.use_dnn_mask:
            self.nmask = nmask
            self.mask_est = MaskEstimator(
                wtype,
                widim,
                wlayers,
                wunits,
                wprojs,
                dropout_rate,
                nmask=nmask,
                nonlinear=nonlinear,
            )
        else:
            # Without DNN masks there is a single WPE branch.
            self.nmask = 1

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[
        Union[ComplexTensor, List[ComplexTensor]],
        torch.LongTensor,
        Optional[Union[torch.Tensor, List[torch.Tensor]]],
        Optional[List[torch.Tensor]],
    ]:
        """DNN_WPE forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F)
            ilens: (B,)

        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, C, F).
                A single tensor when ``nmask == 1``, otherwise one per mask.
                Padded frames are zeroed.
            ilens: (B,), returned unchanged.
            masks (torch.Tensor or List[torch.Tensor] or None): (B, T, C, F).
                A single tensor when ``nmask == 1``; None when
                ``use_dnn_mask`` is False or ``iterations`` is 0.
            power (List[torch.Tensor] or None): (B, F, T), one per branch,
                the power used in the last WPE iteration; None when
                ``iterations`` is 0.
        """
        # (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        # Every branch starts from the observed signal.
        enhanced = [data for _ in range(self.nmask)]
        masks = None
        power = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = [enh.real ** 2 + enh.imag ** 2 for enh in enhanced]
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                masks, _ = self.mask_est(data, ilens)
                # floor masks to increase numerical stability
                if self.mask_flooring:
                    masks = [m.clamp(min=self.flooring_thres) for m in masks]
                if self.normalization:
                    # Normalize along T
                    masks = [m / m.sum(dim=-1, keepdim=True) for m in masks]
                # (..., C, T) * (..., C, T) -> (..., C, T), one mask per branch
                power = [p * m for p, m in zip(power, masks)]

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = [p.mean(dim=-2).clamp(min=self.eps) for p in power]

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE(kamo): Calculate in double precision
            enhanced = [
                wpe_one_iteration(
                    data.contiguous().double(),
                    p.double(),
                    taps=self.taps,
                    delay=self.delay,
                    inverse_power=self.inverse_power,
                )
                for p in power
            ]
            # Cast back to the input dtype and zero out padded frames.
            enhanced = [
                enh.to(dtype=data.dtype).masked_fill(make_pad_mask(ilens, enh.real), 0)
                for enh in enhanced
            ]

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced]
        if masks is not None:
            # (B, F, C, T) -> (B, T, C, F)
            masks = (
                [m.transpose(-1, -3) for m in masks]
                if self.nmask > 1
                else masks[0].transpose(-1, -3)
            )
        if self.nmask == 1:
            enhanced = enhanced[0]

        return enhanced, ilens, masks, power

    def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[
        Optional[Union[torch.Tensor, List[torch.Tensor]]], torch.LongTensor
    ]:
        """Predict mask for WPE dereverberation.

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)

        Returns:
            masks (torch.Tensor or List[torch.Tensor] or None): (B, T, C, F).
                A single tensor when ``nmask == 1``; None when
                ``use_dnn_mask`` is False.
            ilens (torch.Tensor): (B,), as returned by the mask estimator
                (passed through unchanged when no mask is predicted).
        """
        if not self.use_dnn_mask:
            return None, ilens
        # The mask estimator is fed single-precision input in (B, F, C, T).
        masks, ilens = self.mask_est(data.permute(0, 3, 2, 1).float(), ilens)
        # (B, F, C, T) -> (B, T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        if self.nmask == 1:
            masks = masks[0]
        return masks, ilens
|