# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import torch
from .modules import SEANetDecoder
from .modules import SEANetEncoder
from .quantization import ResidualVectorQuantizer
################################################################################
# Encodec neural audio codec
################################################################################
class Encodec(torch.nn.Module):
    """
    Encodec neural audio codec proposed in "High Fidelity Neural Audio
    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.
    """

    def __init__(
        self,
        sample_rate: int,
        channels: int,
        causal: bool,
        model_norm: str,
        target_bandwidths: tp.Sequence[float],
        audio_normalize: bool,
        ratios: tp.Sequence[int] = (8, 5, 4, 2),
        codebook_size: int = 1024,
        n_filters: int = 32,
        true_skip: bool = False,
        encoder_kwargs: tp.Optional[tp.Dict] = None,
        decoder_kwargs: tp.Optional[tp.Dict] = None,
    ):
        """
        Parameters
        ----------
        sample_rate : int
            Audio sample rate in Hz.
        channels : int
            Number of audio channels expected at input.
        causal : bool
            Whether to use causal convolution layers in encoder/decoder.
        model_norm : str
            Type of normalization to use in encoder/decoder.
        target_bandwidths : tp.Sequence[float]
            List of target bandwidths in kb/s.
        audio_normalize : bool
            Whether to normalize encoded and decoded audio segments using
            simple scaling factors.
        ratios : tp.Sequence[int], optional
            List of downsampling ratios used in encoder/decoder, by default
            (8, 5, 4, 2).
        codebook_size : int, optional
            Size of residual vector quantizer codebooks; must be a power of
            two. By default 1024.
        n_filters : int, optional
            Number of filters used in encoder/decoder, by default 32.
        true_skip : bool, optional
            Whether to use true skip connections in encoder/decoder rather
            than convolutional skip connections, by default False.
        encoder_kwargs : tp.Dict, optional
            Extra keyword arguments forwarded to ``SEANetEncoder``.
        decoder_kwargs : tp.Dict, optional
            Extra keyword arguments forwarded to ``SEANetDecoder``.
        """
        super().__init__()
        encoder_kwargs = encoder_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}
        self.encoder = SEANetEncoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **encoder_kwargs,
        )
        self.decoder = SEANetDecoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **decoder_kwargs,
        )
        # Bits emitted per codebook per frame. The original hard-coded 10,
        # which is log2(1024) and thus only correct for the default
        # codebook_size; deriving it keeps the default behavior identical
        # while supporting other power-of-two codebook sizes.
        bits_per_codebook = int(math.log2(codebook_size))
        frames_per_second = math.ceil(sample_rate / self.encoder.hop_length)
        # Maximum number of quantizers needed to reach the largest target
        # bandwidth: kb/s -> bits/s, divided by bits/s contributed by one
        # codebook.
        n_q = int(
            1000
            * target_bandwidths[-1]
            // (frames_per_second * bits_per_codebook)
        )
        self.n_q = n_q  # Maximum number of quantizers
        self.quantizer = ResidualVectorQuantizer(
            dimension=self.encoder.dimension,
            n_q=n_q,
            bins=codebook_size,
        )
        self.sample_rate = sample_rate
        self.normalize = audio_normalize
        self.channels = channels
        # NOTE(review): assumes math.prod(self.encoder.ratios) equals
        # self.encoder.hop_length used above — confirm against SEANetEncoder.
        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))
        self.target_bandwidths = target_bandwidths
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert (
            2**self.bits_per_codebook == self.quantizer.bins
        ), "quantizer bins must be a power of 2."
        # Default to the highest supported bandwidth.
        self.bandwidth = self.target_bandwidths[-1]

    def set_target_bandwidth(self, bandwidth: float):
        """
        Set the target bandwidth for the codec by adjusting the
        number of residual vector quantizers used.

        Parameters
        ----------
        bandwidth : float
            Target bandwidth in kb/s; must be one of `self.target_bandwidths`.

        Raises
        ------
        ValueError
            If `bandwidth` is not among the supported target bandwidths.
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.target_bandwidths}."
            )
        self.bandwidth = bandwidth

    def encode(
        self, x: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.Any, tp.Any, torch.Tensor]:
        """
        Map a given audio waveform `x` to discrete residual latent codes.

        Parameters
        ----------
        x : torch.Tensor
            Audio waveform of shape `(n_batch, n_channels, n_samples)`.

        Returns
        -------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.
        z_O, z_o
            Additional outputs of `ResidualVectorQuantizer.encode`
            (presumably quantized/intermediate latents — confirm against the
            quantizer implementation).
        z : torch.Tensor
            Continuous latents produced by the encoder, before quantization.
        """
        assert x.dim() == 3
        channels = x.shape[1]
        # Mono or stereo input only.
        assert 0 < channels <= 2
        z = self.encoder(x)
        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
        # Quantizer emits codes codebook-first; transpose to batch-first
        # (`decode` applies the inverse transpose).
        codes = codes.transpose(0, 1)
        return codes, z_O, z_o, z

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode quantized latents to obtain waveform audio.

        Parameters
        ----------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.

        Returns
        -------
        out : torch.Tensor
            Tensor of shape `(n_batch, n_channels, n_samples)`.
        """
        # Undo the batch-first transpose applied in `encode`.
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        return out