File size: 5,512 Bytes
f872c8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp

import torch

from .modules import SEANetDecoder
from .modules import SEANetEncoder
from .quantization import ResidualVectorQuantizer

################################################################################
# Encodec neural audio codec
################################################################################


class Encodec(torch.nn.Module):
    """
    Encodec neural audio codec proposed in "High Fidelity Neural Audio
    Compression" (https://arxiv.org/abs/2210.13438) by Défossez et al.
    """

    def __init__(
        self,
        sample_rate: int,
        channels: int,
        causal: bool,
        model_norm: str,
        target_bandwidths: tp.Sequence[float],
        audio_normalize: bool,
        ratios: tp.Sequence[int] = (8, 5, 4, 2),
        codebook_size: int = 1024,
        n_filters: int = 32,
        true_skip: bool = False,
        encoder_kwargs: tp.Optional[tp.Dict] = None,
        decoder_kwargs: tp.Optional[tp.Dict] = None,
    ):
        """
        Parameters
        ----------
        sample_rate : int
            Audio sample rate in Hz.
        channels : int
            Number of audio channels expected at input.
        causal : bool
            Whether to use a causal convolution layers in encoder/decoder.
        model_norm : str
            Type of normalization to use in encoder/decoder.
        target_bandwidths : tp.Sequence[float]
            List of target bandwidths in kb/s.
        audio_normalize : bool
            Whether to normalize encoded and decoded audio segments using
            simple scaling factors
        ratios : tp.Sequence[int], optional
            List of downsampling ratios used in encoder/decoder, by default (8, 5, 4, 2)
        codebook_size : int, optional
            Size of residual vector quantizer codebooks, by default 1024
        n_filters : int, optional
            Number of filters used in encoder/decoder, by default 32
        true_skip : bool, optional
            Whether to use true skip connections in encoder/decoder rather than
            convolutional skip connections, by default False
        encoder_kwargs : tp.Dict, optional
            Extra keyword arguments forwarded to `SEANetEncoder`, by default None
        decoder_kwargs : tp.Dict, optional
            Extra keyword arguments forwarded to `SEANetDecoder`, by default None
        """
        super().__init__()

        encoder_kwargs = encoder_kwargs or {}
        decoder_kwargs = decoder_kwargs or {}

        self.encoder = SEANetEncoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **encoder_kwargs,
        )
        self.decoder = SEANetDecoder(
            channels=channels,
            causal=causal,
            norm=model_norm,
            ratios=ratios,
            n_filters=n_filters,
            true_skip=true_skip,
            **decoder_kwargs,
        )

        # Number of quantizers needed to reach the highest target bandwidth:
        # total bits/s divided by bits/s per quantizer. NOTE(review): the `* 10`
        # hard-codes 10 bits per codebook (log2(1024)); it does not track
        # `codebook_size` if a non-default value is passed — confirm intended.
        n_q = int(
            1000
            * target_bandwidths[-1]
            // (math.ceil(sample_rate / self.encoder.hop_length) * 10)
        )
        self.n_q = n_q  # Maximum number of quantizers
        self.quantizer = ResidualVectorQuantizer(
            dimension=self.encoder.dimension,
            n_q=n_q,
            bins=codebook_size,
        )

        self.sample_rate = sample_rate
        self.normalize = audio_normalize
        self.channels = channels

        # Latent frames per second of audio (prod(ratios) == encoder hop length).
        self.frame_rate = math.ceil(self.sample_rate / math.prod(self.encoder.ratios))

        self.target_bandwidths = target_bandwidths
        self.bits_per_codebook = int(math.log2(self.quantizer.bins))
        assert (
            2**self.bits_per_codebook == self.quantizer.bins
        ), "quantizer bins must be a power of 2."

        # Default to the highest supported bandwidth (uses all quantizers).
        self.bandwidth = self.target_bandwidths[-1]

    def set_target_bandwidth(self, bandwidth: float):
        """
        Set the target bandwidth for the codec by adjusting the
        number of residual vector quantizers used
        """
        if bandwidth not in self.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.target_bandwidths}."
            )
        self.bandwidth = bandwidth

    def encode(
        self, x: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Map a given audio waveform `x` to discrete residual latent codes.

        Parameters
        ----------
        x : torch.Tensor
            Audio waveform of shape `(n_batch, n_channels, n_samples)`.

        Returns
        -------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.
        z_O : torch.Tensor
            Auxiliary tensor returned by the quantizer's `encode`
            (presumably the quantized latents — confirm against
            `ResidualVectorQuantizer.encode`).
        z_o : torch.Tensor
            Second auxiliary tensor returned by the quantizer's `encode`
            (semantics defined by `ResidualVectorQuantizer.encode`).
        z : torch.Tensor
            Unquantized encoder output (continuous latents).
        """
        assert x.dim() == 3
        _, channels, _ = x.shape
        assert 0 < channels <= 2

        z = self.encoder(x)
        codes, z_O, z_o = self.quantizer.encode(z, self.frame_rate, self.bandwidth)
        # Quantizer emits (n_codebooks, n_batch, n_frames); put batch first.
        codes = codes.transpose(0, 1)

        return codes, z_O, z_o, z

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode quantized latents to obtain waveform audio.

        Parameters
        ----------
        codes : torch.Tensor
            Tensor of shape `(n_batch, n_codebooks, n_frames)`.

        Returns
        -------
        out : torch.Tensor
            Tensor of shape `(n_batch, n_channels, n_samples)`.
        """
        # Quantizer expects codebook-first layout; undo the transpose from `encode`.
        codes = codes.transpose(0, 1)
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)

        return out