""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import fairseq
import torch as th
import torchaudio as ta

# path to the pretrained wav2vec checkpoint
wav2vec_model_path = "./assets/wav2vec_large.pt"

def weights_init(m):
    """Xavier-initialize Conv1d weights and set biases to a small constant."""
    if isinstance(m, th.nn.Conv1d):
        th.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            th.nn.init.constant_(m.bias, 0.01)
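
# Usage sketch (illustrative, not part of the original file): `Module.apply`
# walks all submodules, so a single call initializes every Conv1d in a network.
#
#     net = th.nn.Sequential(th.nn.Conv1d(80, 128, kernel_size=3), th.nn.ReLU())
#     net.apply(weights_init)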

class Wav2VecEncoder(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.resampler = ta.transforms.Resample(orig_freq=48000, new_freq=16000)
        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [wav2vec_model_path]
        )
        self.wav2vec_model = models[0]

    def forward(self, audio: th.Tensor):
        """
        :param audio: B x T x 1600 tensor of 48kHz audio samples
                      (1600 samples per visual frame, i.e. 30fps)
        :return: B x T_wav2vec x 512 tensor of wav2vec features
        """
        audio = audio.view(audio.shape[0], audio.shape[1] * 1600)
        audio = self.resampler(audio)
        # zero padding on the left
        audio = th.cat(
            [th.zeros(audio.shape[0], 320, device=audio.device), audio], dim=-1
        )
        x = self.wav2vec_model.feature_extractor(audio)
        x = self.wav2vec_model.feature_aggregator(x)
        x = x.permute(0, 2, 1).contiguous()  # B x 512 x T_wav2vec -> B x T_wav2vec x 512
        return x
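
# Usage sketch (illustrative; shapes assume the wav2vec checkpoint above):
#
#     encoder = Wav2VecEncoder()
#     audio = th.randn(2, 30, 1600)   # 2 clips, 1s each at 30fps / 48kHz
#     feats = encoder(audio)          # -> 2 x T_wav2vec x 512, ~100Hz features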

class Wav2VecDownsampler(th.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = th.nn.Conv1d(512, 512, kernel_size=3)
        self.conv2 = th.nn.Conv1d(512, 512, kernel_size=3)
        self.norm = th.nn.LayerNorm(512)

    def forward(self, x: th.Tensor, target_length: int):
        """
        :param x: B x T x 512 tensor containing wav2vec features at 100Hz
        :param target_length: number of output frames
        :return: B x target_length x 512 tensor containing downsampled wav2vec
                 features at 30Hz
        """
        x = x.permute(0, 2, 1).contiguous()  # B x 512 x T
        # first conv (causal: pad on the left only)
        x = th.nn.functional.pad(x, pad=(2, 0))
        x = th.nn.functional.relu(self.conv1(x))
        # first downsampling, to halfway between input and target length
        x = th.nn.functional.interpolate(x, size=(x.shape[-1] + target_length) // 2)
        # second conv
        x = th.nn.functional.pad(x, pad=(2, 0))
        x = self.conv2(x)
        # second downsampling, to the target length
        x = th.nn.functional.interpolate(x, size=target_length)
        # layer norm over the feature dimension
        x = x.permute(0, 2, 1).contiguous()  # B x target_length x 512
        x = self.norm(x)
        return x
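
# Usage sketch (illustrative): the module downsamples 100Hz wav2vec features to
# a target rate in two stages rather than one aggressive interpolation.
#
#     down = Wav2VecDownsampler()
#     feats = th.randn(2, 100, 512)        # 1s of 100Hz features
#     out = down(feats, target_length=30)  # -> 2 x 30 x 512 (30Hz)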

class AudioTcn(th.nn.Module):
    def __init__(
        self,
        encoding_dim: int = 128,
        use_melspec: bool = True,
        use_wav2vec: bool = True,
    ):
        """
        :param encoding_dim: size of the output encoding
        :param use_melspec: extract mel spectrogram features as input
        :param use_wav2vec: extract wav2vec features as input
        """
        super().__init__()
        self.encoding_dim = encoding_dim
        self.use_melspec = use_melspec
        self.use_wav2vec = use_wav2vec
        if use_melspec:
            # hop_length=400 -> two feature vectors per visual frame
            # (downsampling to 24kHz -> 800 samples per frame)
            self.melspec = th.nn.Sequential(
                ta.transforms.Resample(orig_freq=48000, new_freq=24000),
                ta.transforms.MelSpectrogram(
                    sample_rate=24000,
                    n_fft=1024,
                    win_length=800,
                    hop_length=400,
                    n_mels=80,
                ),
            )
        if use_wav2vec:
            models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
                [wav2vec_model_path]
            )
            self.wav2vec_model = models[0]
            self.wav2vec_model.eval()
            self.wav2vec_postprocess = th.nn.Conv1d(512, 256, kernel_size=3)
            self.wav2vec_postprocess.apply(weights_init)
        # temporal model (each conv is annotated with its receptive field contribution)
        input_dim = (160 if use_melspec else 0) + (256 if use_wav2vec else 0)
        self.layers = th.nn.ModuleList(
            [
                th.nn.Conv1d(
                    input_dim, max(256, encoding_dim), kernel_size=3, dilation=1
                ),  # 2 (+1)
                th.nn.Conv1d(
                    max(256, encoding_dim), encoding_dim, kernel_size=3, dilation=2
                ),  # 4 (+1)
                th.nn.Conv1d(
                    encoding_dim, encoding_dim, kernel_size=3, dilation=3
                ),  # 6 (+1)
                th.nn.Conv1d(
                    encoding_dim, encoding_dim, kernel_size=3, dilation=1
                ),  # 2 (+1)
                th.nn.Conv1d(
                    encoding_dim, encoding_dim, kernel_size=3, dilation=2
                ),  # 4 (+1)
                th.nn.Conv1d(
                    encoding_dim, encoding_dim, kernel_size=3, dilation=3
                ),  # 6 (+1)
            ]
        )
        self.layers.apply(weights_init)
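        # receptive field of the six convs above: a kernel-3 conv with dilation d
        # widens the receptive field by 2*d, so 2*(2 + 4 + 6) + 1 = 25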
        self.receptive_field = 25
        self.final = th.nn.Conv1d(encoding_dim, encoding_dim, kernel_size=1)
        self.final.apply(weights_init)
    def forward(self, audio):
        """
        :param audio: B x T x 1600 tensor containing audio samples for each frame
        :return: B x T x encoding_dim tensor containing audio encodings for each frame
        """
        B, T = audio.shape[0], audio.shape[1]
        # preprocess raw audio signal to extract feature vectors
        audio = audio.view(B, T * 1600)
        x_mel = th.zeros(B, 0, T, device=audio.device)
        x_w2v = th.zeros(B, 0, T, device=audio.device)
        if self.use_melspec:
            # drop the first spectrogram frame so there are exactly 2*T frames
            x_mel = self.melspec(audio)[:, :, 1:].contiguous()
            x_mel = th.log(x_mel.clamp(min=1e-10))
            # stack the two 80-dim vectors per visual frame -> 160 channels
            x_mel = (
                x_mel.permute(0, 2, 1)
                .contiguous()
                .view(x_mel.shape[0], T, 160)
                .permute(0, 2, 1)
                .contiguous()
            )
        if self.use_wav2vec:
            # wav2vec is kept frozen: run the pretrained extractor without gradients
            with th.no_grad():
                x_w2v = ta.functional.resample(audio, 48000, 16000)
                x_w2v = self.wav2vec_model.feature_extractor(x_w2v)
                x_w2v = self.wav2vec_model.feature_aggregator(x_w2v)
            # trainable projection 512 -> 256, then resample to one vector per frame
            x_w2v = self.wav2vec_postprocess(th.nn.functional.pad(x_w2v, pad=[2, 0]))
            x_w2v = th.nn.functional.interpolate(
                x_w2v, size=T, align_corners=True, mode="linear"
            )
        x = th.cat([x_mel, x_w2v], dim=1)
        # process signal with the causal TCN (pad left by receptive_field - 1)
        x = th.nn.functional.pad(x, pad=[self.receptive_field - 1, 0])
        for layer in self.layers:
            y = th.nn.functional.leaky_relu(layer(x), negative_slope=0.2)
            if self.training:
                y = th.nn.functional.dropout(y, 0.2)
            if x.shape[1] == y.shape[1]:
                x = (x[:, :, -y.shape[-1]:] + y) / 2.0  # averaged skip connection
            else:
                x = y
        x = self.final(x)
        x = x.permute(0, 2, 1).contiguous()  # B x encoding_dim x T -> B x T x encoding_dim
        return x
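

# Minimal smoke test (illustrative; assumes the wav2vec checkpoint exists at
# `wav2vec_model_path`): encodes one second of random 30fps audio.
if __name__ == "__main__":
    model = AudioTcn(encoding_dim=128)
    model.eval()
    audio = th.randn(1, 30, 1600)  # B x T x 1600 samples at 48kHz
    with th.no_grad():
        encoding = model(audio)
    print(encoding.shape)  # expected: torch.Size([1, 30, 128])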