oreillyp's picture
initial commit
f872c8a
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import torch
from .modules import SEANetDecoder
from .modules import SEANetEncoder
from .quantization import ResidualVectorQuantizer
################################################################################
# Encodec neural audio codec
################################################################################
class Encodec(torch.nn.Module):
    """
    Encodec neural audio codec proposed in "High Fidelity Neural Audio
    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.

    The model chains a SEANet convolutional encoder, a residual vector
    quantizer applied to the encoder latents, and a SEANet convolutional
    decoder that maps quantized latents back to a waveform.
    """

    def __init__(
        self,
        sample_rate: int,
        channels: int,
        causal: bool,
        model_norm: str,
        target_bandwidths: tp.Sequence[float],
        audio_normalize: bool,
        ratios: tp.Sequence[int] = (8, 5, 4, 2),
        codebook_size: int = 1024,
        n_filters: int = 32,
        true_skip: bool = False,
        encoder_kwargs: tp.Optional[tp.Dict] = None,
        decoder_kwargs: tp.Optional[tp.Dict] = None,
    ):
        """
        Parameters
        ----------
        sample_rate : int
            Audio sample rate in Hz.
        channels : int
            Number of audio channels expected at input.
        causal : bool
            Whether to use causal convolution layers in encoder/decoder.
        model_norm : str
            Type of normalization to use in encoder/decoder.
        target_bandwidths : tp.Sequence[float]
            List of target bandwidths in kb/s. The last entry determines the
            maximum number of quantizers and is the initial bandwidth.
        audio_normalize : bool
            Whether to normalize encoded and decoded audio segments using
            simple scaling factors. Stored as `self.normalize`.
        ratios : tp.Sequence[int], optional
            Downsampling ratios used in encoder/decoder, by default (8, 5, 4, 2)
        codebook_size : int, optional
            Size of residual vector quantizer codebooks; must be a power of 2,
            by default 1024
        n_filters : int, optional
            Number of filters used in encoder/decoder, by default 32
        true_skip : bool, optional
            Whether to use true skip connections in encoder/decoder rather than
            convolutional skip connections, by default False
        encoder_kwargs : tp.Dict, optional
            Extra keyword arguments forwarded to `SEANetEncoder`, by default None
        decoder_kwargs : tp.Dict, optional
            Extra keyword arguments forwarded to `SEANetDecoder`, by default None
        """
        super().__init__()
        encoder_kwargs = encoder_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}

        self.encoder = SEANetEncoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **encoder_kwargs,
        )
        self.decoder = SEANetDecoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **decoder_kwargs,
        )

        # Maximum number of residual quantizers: the highest bandwidth
        # (kb/s -> bits/s) divided by the bits per second produced by one
        # codebook (latent frames per second * bits per codebook).
        # NOTE(review): this uses target_bandwidths[-1], assuming the list is
        # in ascending order -- TODO confirm with callers.
        # NOTE(review): the divisor hard-codes 10 bits per codebook
        # (log2 of the default codebook_size=1024) rather than
        # log2(codebook_size). It is kept as-is so n_q, and therefore
        # checkpoint compatibility, is unchanged.
        n_q = int(
            1000
            * target_bandwidths[-1]
            // (math.ceil(sample_rate / self.encoder.hop_length) * 10)
        )
        self.n_q = n_q  # Maximum number of quantizers
        self.quantizer = ResidualVectorQuantizer(
            dimension=self.encoder.dimension,
            n_q=n_q,
            bins=codebook_size,
        )

        self.sample_rate = sample_rate
        # NOTE(review): stored but not read by encode()/decode() in this class.
        self.normalize = audio_normalize
        self.channels = channels
        # Latent frames per second after the encoder's total downsampling.
        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))
        self.target_bandwidths = target_bandwidths

        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert (
            2**self.bits_per_codebook == self.quantizer.bins
        ), "quantizer bins must be a power of 2."

        # Start at the highest supported bandwidth.
        self.bandwidth = self.target_bandwidths[-1]

    def set_target_bandwidth(self, bandwidth: float):
        """
        Set the target bandwidth for the codec. The value is forwarded to the
        quantizer on each `encode` call, which controls how many residual
        vector quantizers are used.

        Parameters
        ----------
        bandwidth : float
            Target bandwidth in kb/s; must be one of `self.target_bandwidths`.

        Raises
        ------
        ValueError
            If `bandwidth` is not in `self.target_bandwidths`.
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.target_bandwidths}."
            )
        self.bandwidth = bandwidth

    def encode(
        self, x: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Map an audio waveform `x` to discrete residual latent codes.

        Parameters
        ----------
        x : torch.Tensor
            Audio waveform of shape `(n_batch, n_channels, n_samples)` with
            one or two channels.

        Returns
        -------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.
        z_O : torch.Tensor
            Second output of `self.quantizer.encode`, returned unchanged.
        z_o : torch.Tensor
            Third output of `self.quantizer.encode`, returned unchanged.
        z : torch.Tensor
            Continuous (unquantized) encoder latents, `self.encoder(x)`.
        """
        assert x.dim() == 3, (
            "expected x of shape (n_batch, n_channels, n_samples), "
            f"got a {x.dim()}-D tensor"
        )
        channels = x.shape[1]
        assert 0 < channels <= 2, f"expected 1 or 2 channels, got {channels}"

        z = self.encoder(x)
        # The current bandwidth (see set_target_bandwidth) is passed to the
        # quantizer together with the latent frame rate.
        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
        # The quantizer presumably returns (n_codebooks, n_batch, n_frames);
        # swap the first two axes to get (n_batch, n_codebooks, n_frames).
        codes = codes.transpose(0, 1)
        return codes, z_O, z_o, z

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode discrete residual latent codes to obtain waveform audio.

        Parameters
        ----------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`, as returned
            by `encode`.

        Returns
        -------
        out : torch.Tensor
            Tensor of shape `(n_batch, n_channels, n_samples)`.
        """
        # Undo the axis swap performed in `encode` before handing the codes
        # back to the quantizer.
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        return self.decoder(emb)