# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp

import torch

from .modules import SEANetDecoder
from .modules import SEANetEncoder
from .quantization import ResidualVectorQuantizer


################################################################################
# Encodec neural audio codec
################################################################################
class Encodec(torch.nn.Module):
    """
    Encodec neural audio codec proposed in "High Fidelity Neural Audio
    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.

    A SEANet encoder/decoder pair with a residual vector quantizer (RVQ) in
    between. The bandwidth passed to the quantizer on every `encode` call is
    chosen with `set_target_bandwidth`.
    """

    def __init__(
        self,
        sample_rate: int,
        channels: int,
        causal: bool,
        model_norm: str,
        target_bandwidths: tp.Sequence[float],
        audio_normalize: bool,
        ratios: tp.Sequence[int] = (8, 5, 4, 2),
        codebook_size: int = 1024,
        n_filters: int = 32,
        true_skip: bool = False,
        encoder_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        decoder_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    ):
        """
        Parameters
        ----------
        sample_rate : int
            Audio sample rate in Hz.
        channels : int
            Number of audio channels expected at input.
        causal : bool
            Whether to use causal convolution layers in encoder/decoder.
        model_norm : str
            Type of normalization to use in encoder/decoder.
        target_bandwidths : tp.Sequence[float]
            Supported target bandwidths in kb/s. The last entry sizes the
            quantizer and is the initial bandwidth, so the sequence is
            presumably sorted in ascending order.
        audio_normalize : bool
            Whether to normalize encoded and decoded audio segments using
            simple scaling factors. NOTE(review): stored as `self.normalize`
            but not applied anywhere in this class.
        ratios : tp.Sequence[int], optional
            Downsampling ratios used in encoder/decoder, by default (8, 5, 4, 2)
        codebook_size : int, optional
            Size of residual vector quantizer codebooks (a power of 2),
            by default 1024
        n_filters : int, optional
            Number of filters used in encoder/decoder, by default 32
        true_skip : bool, optional
            Whether to use true skip connections in encoder/decoder rather than
            convolutional skip connections, by default False
        encoder_kwargs : dict, optional
            Extra keyword arguments forwarded to `SEANetEncoder`.
        decoder_kwargs : dict, optional
            Extra keyword arguments forwarded to `SEANetDecoder`.

        Raises
        ------
        AssertionError
            If `codebook_size` is not a power of 2 that is at least 2.
        """
        super().__init__()
        # Validate first: the bits carried by one code determine how many
        # quantizers fit into the bandwidth budget computed below.
        assert codebook_size >= 2, "quantizer bins must be a power of 2."
        bits_per_codebook = int(math.log2(codebook_size))
        assert (
            2**bits_per_codebook == codebook_size
        ), "quantizer bins must be a power of 2."

        encoder_kwargs = encoder_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}
        self.encoder = SEANetEncoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **encoder_kwargs,
        )
        self.decoder = SEANetDecoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **decoder_kwargs,
        )
        # Maximum number of quantizers: enough to reach the last bandwidth.
        n_q = self._max_num_quantizers(
            target_bandwidths[-1],
            math.ceil(sample_rate / self.encoder.hop_length),
            bits_per_codebook,
        )
        self.n_q = n_q
        self.quantizer = ResidualVectorQuantizer(
            dimension=self.encoder.dimension,
            n_q=n_q,
            bins=codebook_size,
        )
        self.sample_rate = sample_rate
        self.normalize = audio_normalize
        self.channels = channels
        # NOTE(review): n_q above uses `encoder.hop_length` while this uses
        # prod(encoder.ratios); presumably they are equal — confirm.
        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))
        self.target_bandwidths = target_bandwidths
        self.bits_per_codebook = bits_per_codebook
        self.bandwidth = self.target_bandwidths[-1]

    @staticmethod
    def _max_num_quantizers(
        bandwidth: float, frame_rate: int, bits_per_codebook: int
    ) -> int:
        """
        Return how many codebooks fit into `bandwidth` kb/s.

        Each codebook emits one `bits_per_codebook`-bit code per latent frame,
        so a single codebook costs `frame_rate * bits_per_codebook` bits/s.
        """
        return int(1000 * bandwidth // (frame_rate * bits_per_codebook))

    def set_target_bandwidth(self, bandwidth: float):
        """
        Set the target bandwidth used by subsequent `encode` calls.

        The value is handed to the quantizer's `encode`, which decides how
        many residual vector quantizers to use for that bandwidth.

        Raises
        ------
        ValueError
            If `bandwidth` is not one of `self.target_bandwidths`.
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.target_bandwidths}."
            )
        self.bandwidth = bandwidth

    def encode(
        self, x: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.Any, tp.Any, tp.Any]:
        """
        Map an audio waveform `x` to discrete residual latent codes.

        Parameters
        ----------
        x : torch.Tensor
            Audio waveform of shape `(n_batch, n_channels, n_samples)`.

        Returns
        -------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.
        z_O, z_o
            The second and third outputs of `ResidualVectorQuantizer.encode`,
            returned unchanged (see that method for their meaning).
        z
            The continuous encoder output that was fed to the quantizer.
        """
        assert x.dim() == 3, "expected input of shape (n_batch, n_channels, n_samples)"
        n_channels = x.shape[1]
        assert 0 < n_channels <= 2, "only mono or stereo input is supported"
        z = self.encoder(x)
        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
        # Swap the first two axes so codes are batch-first, as documented.
        codes = codes.transpose(0, 1)
        return codes, z_O, z_o, z

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode quantized latents to obtain waveform audio.

        Parameters
        ----------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.

        Returns
        -------
        out : torch.Tensor
            Tensor of shape `(n_batch, n_channels, n_samples)`.
        """
        # Undo the axis swap performed in `encode` before handing to the RVQ.
        emb = self.quantizer.decode(codes.transpose(0, 1))
        return self.decoder(emb)