File size: 5,207 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
from collections import OrderedDict
from typing import Tuple
import warnings
import torch
from espnet2.enh.separator.abs_separator import AbsSeparator
class AsteroidModel_Converter(AbsSeparator):
    """Adapter that exposes an Asteroid model through the AbsSeparator API.

    The wrapped model works directly on raw waveforms, so the converter
    requires ``encoder_output_dim == 1`` and returns one estimated waveform
    per speaker.
    """

    def __init__(
        self,
        encoder_output_dim: int,
        model_name: str,
        num_spk: int,
        pretrained_path: str = "",
        loss_type: str = "si_snr",
        **model_related_kwargs,
    ):
        """Build (or load) the Asteroid model to be wrapped.

        Args:
            encoder_output_dim: input feature dimension, must be 1
                (raw waveform, i.e. the output of the NullEncoder).
            model_name: Asteroid model name, e.g. ConvTasNet, DPTNet. Refer to
                https://github.com/asteroid-team/asteroid/
                blob/master/asteroid/models/__init__.py
            num_spk: number of speakers.
            pretrained_path: name of a pretrained Asteroid model in the HF hub.
                Refer to: https://github.com/asteroid-team/asteroid/
                blob/master/docs/source/readmes/pretrained_models.md and
                https://huggingface.co/models?filter=asteroid
                When non-empty, the model is loaded with ``from_pretrained``
                and ``model_related_kwargs`` are not forwarded to it.
            loss_type: loss type of enhancement; only "si_snr" is accepted.
            model_related_kwargs: extra arguments for the specific Asteroid
                model. The string "None" is converted to ``None``.

        Raises:
            AssertionError: if ``encoder_output_dim`` is not 1.
            ValueError: if ``loss_type`` is not "si_snr".
        """
        super(AsteroidModel_Converter, self).__init__()

        assert (
            encoder_output_dim == 1
        ), encoder_output_dim  # The input should be in the raw-wave domain.

        # Reject an unsupported loss before building or fetching the model,
        # so that a bad configuration fails without that (possibly costly) work.
        if loss_type != "si_snr":
            raise ValueError("Unsupported loss type: %s" % loss_type)

        # Imported here rather than at module level so that this module can
        # be imported without Asteroid; please make sure Asteroid is installed.
        # https://github.com/asteroid-team/asteroid
        from asteroid import models

        # Values may arrive as the literal string "None"; map them to None.
        model_related_kwargs = {
            k: None if v == "None" else v for k, v in model_related_kwargs.items()
        }

        model_class = getattr(models, model_name)
        if pretrained_path:
            model = model_class.from_pretrained(pretrained_path)
            if model_related_kwargs:
                # The extra kwargs are not passed to the pretrained model.
                warnings.warn(
                    "Pratrained model should get no args with %s" % model_related_kwargs
                )
        else:
            model = model_class(**model_related_kwargs)

        self.model = model
        self._num_spk = num_spk
        self.loss_type = loss_type

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[list, torch.Tensor, OrderedDict]:
        """Whole forward of asteroid models.

        Args:
            input (torch.Tensor): Raw Waveforms [B, T] (or [T] for one sample)
            ilens (torch.Tensor): input lengths [B]; returned unchanged

        Returns:
            estimated Waveforms (List[torch.Tensor]): [(B, T), ...], one entry
                per speaker ((T,) each when the input is 1-D)
            ilens (torch.Tensor): (B,)
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, T),
                'mask_spk2': torch.Tensor(Batch, T),
                ...
                'mask_spkn': torch.Tensor(Batch, T),
            ]
            Note that these entries are the estimated waveforms themselves
            (the same tensors as in the first return value), not masks.
        """
        if hasattr(self.model, "forward_wav"):
            est_source = self.model.forward_wav(input)  # B,nspk,T or nspk,T
        else:
            est_source = self.model(input)  # B,nspk,T or nspk,T

        # The speaker axis is 0 for an unbatched output and 1 for a batched
        # one; split along that axis so each list entry is one speaker.
        if input.dim() == 1:
            assert est_source.size(0) == self.num_spk, est_source.size(0)
            est_source = list(est_source.unbind(dim=0))  # List(T)
        else:
            assert est_source.size(1) == self.num_spk, est_source.size(1)
            est_source = list(est_source.unbind(dim=1))  # List(B,T)

        masks = OrderedDict(
            zip(["mask_spk{}".format(i + 1) for i in range(self.num_spk)], est_source)
        )
        return est_source, ilens, masks

    def forward_rawwav(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[list, torch.Tensor, OrderedDict]:
        """Output with waveforms.

        Delegates to :meth:`forward` and returns its result unchanged.
        """
        return self.forward(input, ilens)

    @property
    def num_spk(self):
        """Number of speakers given at construction time."""
        return self._num_spk
if __name__ == "__main__":
    # Manual smoke test: run the converter once with a pretrained model and
    # once with a model built from explicit hyper-parameters.
    wav_mix = torch.randn(3, 16000)
    print("mixture shape", wav_mix.shape)

    # Case 1: model loaded through `pretrained_path`.
    pretrained_cfg = dict(
        model_name="ConvTasNet",
        encoder_output_dim=1,
        num_spk=2,
        loss_type="si_snr",
        pretrained_path="mpariente/ConvTasNet_WHAM!_sepclean",
    )
    separator = AsteroidModel_Converter(**pretrained_cfg)
    print("model", separator)
    estimates, *_rest = separator(wav_mix)
    # The second positional argument is `ilens`, which is passed through as is.
    estimates, *_rest = separator.forward_rawwav(wav_mix, 111)
    print("output spk1 shape", estimates[0].shape)

    # Case 2: model built from scratch; everything in this dict is forwarded
    # to the Asteroid model constructor via `model_related_kwargs`.
    convtasnet_kwargs = dict(
        n_src=2,
        out_chan=None,
        n_blocks=2,
        n_repeats=2,
        bn_chan=128,
        hid_chan=512,
        skip_chan=128,
        conv_kernel_size=3,
        norm_type="gLN",
        mask_act="sigmoid",
        in_chan=None,
        fb_name="free",
        kernel_size=16,
        n_filters=512,
        stride=8,
        encoder_activation=None,
        sample_rate=8000,
    )
    separator = AsteroidModel_Converter(
        encoder_output_dim=1,
        num_spk=2,
        model_name="ConvTasNet",
        loss_type="si_snr",
        **convtasnet_kwargs,
    )
    print("\n\nmodel", separator)
    estimates, *_rest = separator(wav_mix)
    first_spk_shape = estimates[0].shape
    print("output spk1 shape", first_spk_shape)
    print("Finished", first_spk_shape)
|