|
import torch
|
|
from scipy.signal import get_window
|
|
|
|
from torch import nn
|
|
|
|
'''
|
|
class LambdaOverlapAdd(torch.nn.Module):
|
|
"""Overlap-add with lambda transform on segments.
|
|
|
|
Segment input signal, apply lambda function (a neural network for example)
|
|
and combine with OLA.
|
|
|
|
Args:
|
|
nnet (callable): Function to apply to each segment.
|
|
n_src (int): Number of sources in the output of nnet.
|
|
window_size (int): Size of segmenting window.
|
|
hop_size (int): Segmentation hop size.
|
|
window (str): Name of the window (see scipy.signal.get_window) used
|
|
for the synthesis.
|
|
reorder_chunks (bool): Whether to reorder each consecutive segment.
|
|
This might be useful when `nnet` is permutation invariant, as
|
|
source assignements might change output channel from one segment
|
|
to the next (in classic speech separation for example).
|
|
Reordering is performed based on the correlation between
|
|
the overlapped part of consecutive segment.
|
|
|
|
Examples:
|
|
>>> from asteroid_test import ConvTasNet
|
|
>>> nnet = ConvTasNet(n_src=2)
|
|
>>> continuous_nnet = LambdaOverlapAdd(
|
|
>>> nnet=nnet,
|
|
>>> n_src=2,
|
|
>>> window_size=64000,
|
|
>>> hop_size=None,
|
|
>>> window="hanning",
|
|
>>> reorder_chunks=True,
|
|
>>> enable_grad=False,
|
|
>>> )
|
|
>>> wav = torch.randn(1, 1, 500000)
|
|
>>> out_wavs = continuous_nnet.forward(wav)
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
nnet,
|
|
n_src,
|
|
window_size,
|
|
hop_size=None,
|
|
window="hanning",
|
|
reorder_chunks=True,
|
|
enable_grad=False,
|
|
):
|
|
super().__init__()
|
|
assert window_size % 2 == 0, "Window size must be even"
|
|
|
|
self.nnet = nnet
|
|
self.window_size = window_size
|
|
self.hop_size = hop_size if hop_size is not None else window_size // 2
|
|
self.n_src = n_src
|
|
|
|
if window:
|
|
window = get_window(window, self.window_size).astype("float32")
|
|
window = torch.from_numpy(window)
|
|
self.use_window = True
|
|
else:
|
|
self.use_window = False
|
|
|
|
self.register_buffer("window", window)
|
|
self.reorder_chunks = reorder_chunks
|
|
self.enable_grad = enable_grad
|
|
|
|
def ola_forward(self, x):
|
|
"""Heart of the class: segment signal, apply func, combine with OLA."""
|
|
|
|
assert x.ndim == 3
|
|
|
|
batch, channels, n_frames = x.size()
|
|
# Overlap and add:
|
|
# [batch, chans, n_frames] -> [batch, chans, win_size, n_chunks]
|
|
unfolded = torch.nn.functional.unfold(
|
|
x.unsqueeze(-1),
|
|
kernel_size=(self.window_size, 1),
|
|
padding=(self.window_size, 0),
|
|
stride=(self.hop_size, 1),
|
|
)
|
|
|
|
out = []
|
|
n_chunks = unfolded.shape[-1]
|
|
for frame_idx in range(n_chunks): # for loop to spare memory
|
|
frame = self.nnet(unfolded[..., frame_idx])
|
|
# user must handle multichannel by reshaping to batch
|
|
if frame_idx == 0:
|
|
assert frame.ndim == 3, "nnet should return (batch, n_src, time)"
|
|
assert frame.shape[1] == self.n_src, "nnet should return (batch, n_src, time)"
|
|
frame = frame.reshape(batch * self.n_src, -1)
|
|
|
|
if frame_idx != 0 and self.reorder_chunks:
|
|
# we determine best perm based on xcorr with previous sources
|
|
frame = _reorder_sources(
|
|
frame, out[-1], self.n_src, self.window_size, self.hop_size
|
|
)
|
|
|
|
if self.use_window:
|
|
frame = frame * self.window
|
|
else:
|
|
frame = frame / (self.window_size / self.hop_size)
|
|
out.append(frame)
|
|
|
|
out = torch.stack(out).reshape(n_chunks, batch * self.n_src, self.window_size)
|
|
out = out.permute(1, 2, 0)
|
|
|
|
out = torch.nn.functional.fold(
|
|
out,
|
|
(n_frames, 1),
|
|
kernel_size=(self.window_size, 1),
|
|
padding=(self.window_size, 0),
|
|
stride=(self.hop_size, 1),
|
|
)
|
|
return out.squeeze(-1).reshape(batch, self.n_src, -1)
|
|
|
|
def forward(self, x):
|
|
"""Forward module: segment signal, apply func, combine with OLA.
|
|
|
|
Args:
|
|
x (:class:`torch.Tensor`): waveform signal of shape (batch, 1, time).
|
|
|
|
Returns:
|
|
:class:`torch.Tensor`: The output of the lambda OLA.
|
|
"""
|
|
# Here we can do the reshaping
|
|
with torch.autograd.set_grad_enabled(self.enable_grad):
|
|
olad = self.ola_forward(x)
|
|
return olad
|
|
|
|
|
|
def _reorder_sources(
|
|
current: torch.FloatTensor,
|
|
previous: torch.FloatTensor,
|
|
n_src: int,
|
|
window_size: int,
|
|
hop_size: int,
|
|
):
|
|
"""
|
|
Reorder sources in current chunk to maximize correlation with previous chunk.
|
|
Used for Continuous Source Separation. Standard dsp correlation is used
|
|
for reordering.
|
|
|
|
|
|
Args:
|
|
current (:class:`torch.Tensor`): current chunk, tensor
|
|
of shape (batch, n_src, window_size)
|
|
previous (:class:`torch.Tensor`): previous chunk, tensor
|
|
of shape (batch, n_src, window_size)
|
|
n_src (:class:`int`): number of sources.
|
|
window_size (:class:`int`): window_size, equal to last dimension of
|
|
both current and previous.
|
|
hop_size (:class:`int`): hop_size between current and previous tensors.
|
|
|
|
Returns:
|
|
current:
|
|
|
|
"""
|
|
batch, frames = current.size()
|
|
current = current.reshape(-1, n_src, frames)
|
|
previous = previous.reshape(-1, n_src, frames)
|
|
|
|
overlap_f = window_size - hop_size
|
|
|
|
def reorder_func(x, y):
|
|
x = x[..., :overlap_f]
|
|
y = y[..., -overlap_f:]
|
|
# Mean normalization
|
|
x = x - x.mean(-1, keepdim=True)
|
|
y = y - y.mean(-1, keepdim=True)
|
|
# Negative mean Correlation
|
|
return -torch.sum(x.unsqueeze(1) * y.unsqueeze(2), dim=-1)
|
|
|
|
# We maximize correlation-like between previous and current.
|
|
pit = PITLossWrapper(reorder_func)
|
|
current = pit(current, previous, return_est=True)[1]
|
|
return current.reshape(batch, frames)
|
|
'''
|
|
|
|
|
|
class DualPathProcessing(nn.Module):
|
|
"""Perform Dual-Path processing via overlap-add as in DPRNN [1].
|
|
|
|
Args:
|
|
chunk_size (int): Size of segmenting window.
|
|
hop_size (int): segmentation hop size.
|
|
|
|
References:
|
|
[1] "Dual-path RNN: efficient long sequence modeling for
|
|
time-domain single-channel speech separation", Yi Luo, Zhuo Chen
|
|
and Takuya Yoshioka. https://arxiv.org/abs/1910.06379
|
|
"""
|
|
|
|
def __init__(self, chunk_size, hop_size):
|
|
super(DualPathProcessing, self).__init__()
|
|
self.chunk_size = chunk_size
|
|
self.hop_size = hop_size
|
|
self.n_orig_frames = None
|
|
|
|
def unfold(self, x):
|
|
"""Unfold the feature tensor from
|
|
|
|
(batch, channels, time) to (batch, channels, chunk_size, n_chunks).
|
|
|
|
Args:
|
|
x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).
|
|
|
|
Returns:
|
|
x: (:class:`torch.Tensor`): spliced feature tensor of shape
|
|
(batch, channels, chunk_size, n_chunks).
|
|
|
|
"""
|
|
|
|
batch, chan, frames = x.size()
|
|
assert x.ndim == 3
|
|
self.n_orig_frames = x.shape[-1]
|
|
unfolded = torch.nn.functional.unfold(
|
|
x.unsqueeze(-1),
|
|
kernel_size=(self.chunk_size, 1),
|
|
padding=(self.chunk_size, 0),
|
|
stride=(self.hop_size, 1),
|
|
)
|
|
|
|
return unfolded.reshape(
|
|
batch, chan, self.chunk_size, -1
|
|
)
|
|
|
|
def fold(self, x, output_size=None):
|
|
"""Folds back the spliced feature tensor.
|
|
|
|
Input shape (batch, channels, chunk_size, n_chunks) to original shape
|
|
(batch, channels, time) using overlap-add.
|
|
|
|
Args:
|
|
x: (:class:`torch.Tensor`): spliced feature tensor of shape
|
|
(batch, channels, chunk_size, n_chunks).
|
|
output_size: (int, optional): sequence length of original feature tensor.
|
|
If None, the original length cached by the previous call of `unfold`
|
|
will be used.
|
|
|
|
Returns:
|
|
x: (:class:`torch.Tensor`): feature tensor of shape (batch, channels, time).
|
|
|
|
.. note:: `fold` caches the original length of the pr
|
|
|
|
"""
|
|
output_size = output_size if output_size is not None else self.n_orig_frames
|
|
|
|
batch, chan, chunk_size, n_chunks = x.size()
|
|
to_unfold = x.reshape(batch, chan * self.chunk_size, n_chunks)
|
|
x = torch.nn.functional.fold(
|
|
to_unfold,
|
|
(output_size, 1),
|
|
kernel_size=(self.chunk_size, 1),
|
|
padding=(self.chunk_size, 0),
|
|
stride=(self.hop_size, 1),
|
|
)
|
|
|
|
x /= self.chunk_size / self.hop_size
|
|
|
|
return x.reshape(batch, chan, self.n_orig_frames)
|
|
|
|
@staticmethod
|
|
def intra_process(x, module):
|
|
"""Performs intra-chunk processing.
|
|
|
|
Args:
|
|
x (:class:`torch.Tensor`): spliced feature tensor of shape
|
|
(batch, channels, chunk_size, n_chunks).
|
|
module (:class:`torch.nn.Module`): module one wish to apply to each chunk
|
|
of the spliced feature tensor.
|
|
|
|
|
|
Returns:
|
|
x (:class:`torch.Tensor`): processed spliced feature tensor of shape
|
|
(batch, channels, chunk_size, n_chunks).
|
|
|
|
.. note:: the module should have the channel first convention and accept
|
|
a 3D tensor of shape (batch, channels, time).
|
|
"""
|
|
|
|
|
|
batch, channels, chunk_size, n_chunks = x.size()
|
|
|
|
x = x.transpose(1, -1).reshape(batch * n_chunks, chunk_size, channels).transpose(1, -1)
|
|
x = module(x)
|
|
x = x.reshape(batch, n_chunks, channels, chunk_size).transpose(1, -1).transpose(1, 2)
|
|
return x
|
|
|
|
@staticmethod
|
|
def inter_process(x, module):
|
|
"""Performs inter-chunk processing.
|
|
|
|
Args:
|
|
x (:class:`torch.Tensor`): spliced feature tensor of shape
|
|
(batch, channels, chunk_size, n_chunks).
|
|
module (:class:`torch.nn.Module`): module one wish to apply between
|
|
each chunk of the spliced feature tensor.
|
|
|
|
|
|
Returns:
|
|
x (:class:`torch.Tensor`): processed spliced feature tensor of shape
|
|
(batch, channels, chunk_size, n_chunks).
|
|
|
|
.. note:: the module should have the channel first convention and accept
|
|
a 3D tensor of shape (batch, channels, time).
|
|
"""
|
|
|
|
batch, channels, chunk_size, n_chunks = x.size()
|
|
x = x.transpose(1, 2).reshape(batch * chunk_size, channels, n_chunks)
|
|
x = module(x)
|
|
x = x.reshape(batch, chunk_size, channels, n_chunks).transpose(1, 2)
|
|
return x
|
|
|